diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index f12d629f162..6be1de601e6 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -121,7 +121,7 @@ def cond(condition, true_fn: Callable, false_fn: Callable = None, elifs=()): Each branch can receive arguments, but the arguments must be the same for all branches. Both the arguments and the branches must be JAX-compatible. If a branch returns one or more variables, every other branch must return the same abstract values. - If a branch returns one or more operators, these will be appended to the QueuingManager. + If a branch returns one or more operators, these will be applied to the circuit. .. note:: @@ -392,7 +392,6 @@ def qnode(a, x, y, z): return cond_func if qml.capture.enabled(): - print("Capture mode for cond") return _capture_cond(condition, true_fn, false_fn, elifs) if elifs: @@ -469,7 +468,7 @@ def true_branch(args): def elif_branch(args, elifs_conditions, jaxpr_elifs): if not jaxpr_elifs: - return false_branch(args) + return None pred = elifs_conditions[0] rest_preds = elifs_conditions[1:] jaxpr_elif = jaxpr_elifs[0] @@ -485,9 +484,12 @@ def false_branch(args): if condition: return true_branch(args) - if elifs_conditions.size > 0: - return elif_branch(args, elifs_conditions, jaxpr_elifs) - return false_branch(args) + + elif_branch_out = ( + elif_branch(args, elifs_conditions, jaxpr_elifs) if elifs_conditions.size > 0 else None + ) + + return false_branch(args) if elif_branch_out is None else elif_branch_out @cond_prim.def_abstract_eval def _(*_, jaxpr_true, jaxpr_false, jaxpr_elifs): @@ -500,24 +502,33 @@ def validate_abstract_values( ) -> None: """Ensure the collected abstract values match the expected ones.""" - assert len(outvals) == len(expected_outvals), ( - f"Mismatch in number of output variables in {branch_type} branch" - f"{'' if index is None else ' #' + str(index)}: " - f"{len(outvals)} vs {len(expected_outvals)}" - ) - for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)): - assert outval == expected_outval, ( - f"Mismatch in output abstract values in {branch_type} branch" - f"{'' if index is None else ' #' + str(index)} at position {i}: " - f"{outval} vs {expected_outval}" + if len(outvals) != len(expected_outvals): + raise ValueError( + f"Mismatch in number of output variables in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)}: " + f"{len(outvals)} vs {len(expected_outvals)}" ) + for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)): + if outval != expected_outval: + raise ValueError( + f"Mismatch in output abstract values in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)} at position {i}: " + f"{outval} vs {expected_outval}" + ) + outvals_true = jaxpr_true.out_avals if jaxpr_false is not None: outvals_false = jaxpr_false.out_avals validate_abstract_values(outvals_false, outvals_true, "false") + else: + if outvals_true is not None: + raise ValueError( + "The false branch must be provided if the true branch returns any variables" + ) + for idx, jaxpr_elif in enumerate(jaxpr_elifs): outvals_elif = jaxpr_elif.out_avals validate_abstract_values(outvals_elif, outvals_true, "elif", idx) @@ -543,9 +554,7 @@ def new_wrapper(*args, **kwargs): jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) jaxpr_false = ( - (jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None) - if false_fn - else None + jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None ) # We extract each condition (or predicate) from the elifs argument list diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index 934f37c4ae3..6cec5e68246 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -180,19 +180,19 @@ class TestCondReturns: ( lambda x: (x + 1, x + 2), lambda x: None, - AssertionError, + ValueError, r"Mismatch in number of output variables", ), ( lambda x: (x + 1, x + 2), lambda x: (x + 1,), - AssertionError, + ValueError, r"Mismatch in number of output variables", ), ( lambda x: (x + 1, x + 2), lambda x: (x + 1, x + 2.0), - AssertionError, + ValueError, r"Mismatch in output abstract values", ), ], @@ -211,7 +211,7 @@ def true_fn(x): def false_fn(x): return x + 1 - with pytest.raises(AssertionError, match=r"Mismatch in number of output variables"): + with pytest.raises(ValueError, match=r"Mismatch in number of output variables"): jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) def test_validate_output_variable_types(self): @@ -223,9 +223,21 @@ def true_fn(x): def false_fn(x): return x + 1, x + 2.0 - with pytest.raises(AssertionError, match=r"Mismatch in output abstract values"): + with pytest.raises(ValueError, match=r"Mismatch in output abstract values"): jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + def test_validate_no_false_branch_with_return(self): + """Test no false branch provided with return variables.""" + + def true_fn(x): + return x + 1, x + 2 + + with pytest.raises( + ValueError, + match=r"The false branch must be provided if the true branch returns any variables", + ): + jax.make_jaxpr(_capture_cond(True, true_fn))(jax.numpy.array(1)) + def test_validate_elif_branches(self): """Test elif branch mismatches.""" @@ -245,14 +257,14 @@ def elif_fn3(x): return x + 1 with pytest.raises( - AssertionError, match=r"Mismatch in output abstract values in elif branch #1" + ValueError, match=r"Mismatch in output abstract values in elif branch #1" ): jax.make_jaxpr( _capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) )(jax.numpy.array(1)) with pytest.raises( - AssertionError, match=r"Mismatch in number of output variables in elif branch #0" + ValueError, match=r"Mismatch in number of output variables in elif branch #0" ): jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))( jax.numpy.array(1)