Skip to content

Commit

Permalink
Suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
PietropaoloFrisoni committed Jul 25, 2024
1 parent 2cb2538 commit b8f80fc
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 26 deletions.
47 changes: 28 additions & 19 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def cond(condition, true_fn: Callable, false_fn: Callable = None, elifs=()):
Each branch can receive arguments, but the arguments must be the same for all branches.
Both the arguments and the branches must be JAX-compatible.
If a branch returns one or more variables, every other branch must return the same abstract values.
If a branch returns one or more operators, these will be appended to the QueuingManager.
If a branch returns one or more operators, these will be applied to the circuit.
.. note::
Expand Down Expand Up @@ -392,7 +392,6 @@ def qnode(a, x, y, z):
return cond_func

if qml.capture.enabled():
print("Capture mode for cond")
return _capture_cond(condition, true_fn, false_fn, elifs)

if elifs:
Expand Down Expand Up @@ -469,7 +468,7 @@ def true_branch(args):

def elif_branch(args, elifs_conditions, jaxpr_elifs):
if not jaxpr_elifs:
return false_branch(args)
return None
pred = elifs_conditions[0]
rest_preds = elifs_conditions[1:]
jaxpr_elif = jaxpr_elifs[0]
Expand All @@ -485,9 +484,12 @@ def false_branch(args):

if condition:
return true_branch(args)
if elifs_conditions.size > 0:
return elif_branch(args, elifs_conditions, jaxpr_elifs)
return false_branch(args)

elif_branch_out = (
elif_branch(args, elifs_conditions, jaxpr_elifs) if elifs_conditions.size > 0 else None
)

return false_branch(args) if elif_branch_out is None else elif_branch_out

@cond_prim.def_abstract_eval
def _(*_, jaxpr_true, jaxpr_false, jaxpr_elifs):
Expand All @@ -500,24 +502,33 @@ def validate_abstract_values(
) -> None:
"""Ensure the collected abstract values match the expected ones."""

assert len(outvals) == len(expected_outvals), (
f"Mismatch in number of output variables in {branch_type} branch"
f"{'' if index is None else ' #' + str(index)}: "
f"{len(outvals)} vs {len(expected_outvals)}"
)
for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)):
assert outval == expected_outval, (
f"Mismatch in output abstract values in {branch_type} branch"
f"{'' if index is None else ' #' + str(index)} at position {i}: "
f"{outval} vs {expected_outval}"
if len(outvals) != len(expected_outvals):
raise ValueError(
f"Mismatch in number of output variables in {branch_type} branch"
f"{'' if index is None else ' #' + str(index)}: "
f"{len(outvals)} vs {len(expected_outvals)}"
)

for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)):
if outval != expected_outval:
raise ValueError(
f"Mismatch in output abstract values in {branch_type} branch"
f"{'' if index is None else ' #' + str(index)} at position {i}: "
f"{outval} vs {expected_outval}"
)

outvals_true = jaxpr_true.out_avals

if jaxpr_false is not None:
outvals_false = jaxpr_false.out_avals
validate_abstract_values(outvals_false, outvals_true, "false")

else:
if outvals_true is not None:
raise ValueError(
"The false branch must be provided if the true branch returns any variables"
)

for idx, jaxpr_elif in enumerate(jaxpr_elifs):
outvals_elif = jaxpr_elif.out_avals
validate_abstract_values(outvals_elif, outvals_true, "elif", idx)
Expand All @@ -543,9 +554,7 @@ def new_wrapper(*args, **kwargs):

jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args)
jaxpr_false = (
(jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None)
if false_fn
else None
jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None
)

# We extract each condition (or predicate) from the elifs argument list
Expand Down
26 changes: 19 additions & 7 deletions tests/capture/test_capture_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,19 +180,19 @@ class TestCondReturns:
(
lambda x: (x + 1, x + 2),
lambda x: None,
AssertionError,
ValueError,
r"Mismatch in number of output variables",
),
(
lambda x: (x + 1, x + 2),
lambda x: (x + 1,),
AssertionError,
ValueError,
r"Mismatch in number of output variables",
),
(
lambda x: (x + 1, x + 2),
lambda x: (x + 1, x + 2.0),
AssertionError,
ValueError,
r"Mismatch in output abstract values",
),
],
Expand All @@ -211,7 +211,7 @@ def true_fn(x):
def false_fn(x):
return x + 1

with pytest.raises(AssertionError, match=r"Mismatch in number of output variables"):
with pytest.raises(ValueError, match=r"Mismatch in number of output variables"):
jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1))

def test_validate_output_variable_types(self):
Expand All @@ -223,9 +223,21 @@ def true_fn(x):
def false_fn(x):
return x + 1, x + 2.0

with pytest.raises(AssertionError, match=r"Mismatch in output abstract values"):
with pytest.raises(ValueError, match=r"Mismatch in output abstract values"):
jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1))

def test_validate_no_false_branch_with_return(self):
"""Test no false branch provided with return variables."""

def true_fn(x):
return x + 1, x + 2

with pytest.raises(
ValueError,
match=r"The false branch must be provided if the true branch returns any variables",
):
jax.make_jaxpr(_capture_cond(True, true_fn))(jax.numpy.array(1))

def test_validate_elif_branches(self):
"""Test elif branch mismatches."""

Expand All @@ -245,14 +257,14 @@ def elif_fn3(x):
return x + 1

with pytest.raises(
AssertionError, match=r"Mismatch in output abstract values in elif branch #1"
ValueError, match=r"Mismatch in output abstract values in elif branch #1"
):
jax.make_jaxpr(
_capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)])
)(jax.numpy.array(1))

with pytest.raises(
AssertionError, match=r"Mismatch in number of output variables in elif branch #0"
ValueError, match=r"Mismatch in number of output variables in elif branch #0"
):
jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))(
jax.numpy.array(1)
Expand Down

0 comments on commit b8f80fc

Please sign in to comment.