Skip to content

Commit

Permalink
Adding tests to catch errors
Browse files Browse the repository at this point in the history
  • Loading branch information
PietropaoloFrisoni committed Jul 24, 2024
1 parent f9ad42a commit e8fef5d
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 96 deletions.
9 changes: 3 additions & 6 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ def run_jaxpr(jaxpr, *args):

out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)

# If the branch returns an Operator, we append it to the QueuingManager
# so that it is applied to the circuit
for outvar in out:
if isinstance(outvar, Operator):
QueuingManager.append(outvar)
Expand Down Expand Up @@ -502,13 +504,8 @@ def validate_abstract_values(
f"{len(outvals)} vs {len(expected_outvals)}"
)
for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)):
assert isinstance(outval, type(expected_outval)), (
f"Mismatch in output variable types in {branch_type} branch"
f"{'' if index is None else ' #' + str(index)} at position {i}: "
f"{type(outval)} vs {type(expected_outval)}"
)
assert outval == expected_outval, (
f"Mismatch in output variable values in {branch_type} branch"
f"Mismatch in output abstract values in {branch_type} branch"
f"{'' if index is None else ' #' + str(index)} at position {i}: "
f"{outval} vs {expected_outval}"
)
Expand Down
235 changes: 145 additions & 90 deletions tests/capture/test_capture_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
"""
Tests for capturing conditionals into jaxpr.
"""
import numpy as np

# pylint: disable=protected-access
# pylint: disable=redefined-outer-name

import jax.numpy as jnp
import numpy as np
import pytest

import pennylane as qml
from pennylane.ops.op_math.condition import _capture_cond

pytestmark = pytest.mark.jax

Expand All @@ -34,8 +37,9 @@ def enable_disable_plxpr():
qml.capture.disable()


def cond_true_elifs_false(selector, arg):
"""A function with conditional containing true, elifs, and false branches."""
@pytest.fixture
def testing_functions():
"""Returns a set of functions for testing."""

def true_fn(arg):
return 2 * arg
Expand All @@ -55,7 +59,26 @@ def elif_fn4(arg):
def false_fn(arg):
return 3 * arg

return qml.cond(
return true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4


@pytest.mark.parametrize(
"selector, arg, expected",
[
(1, 10, 20), # True condition
(-1, 10, 9), # Elif condition 1
(-2, 10, 8), # Elif condition 2
(-3, 10, 7), # Elif condition 3
(-4, 10, 6), # Elif condition 4
(0, 10, 30), # False condition
],
)
def test_cond_true_elifs_false(testing_functions, selector, arg, expected):
"""Test the conditional with true, elifs, and false branches."""

true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4 = testing_functions

result = qml.cond(
selector > 0,
true_fn,
false_fn,
Expand All @@ -66,116 +89,148 @@ def false_fn(arg):
(selector == -4, elif_fn4),
),
)(arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"


def cond_true_elifs(selector, arg):
"""A function with conditional containing true and elifs branches."""

def true_fn(arg):
return 2 * arg

def elif_fn1(arg):
return arg - 1
@pytest.mark.parametrize(
"selector, arg, expected",
[
(1, 10, 20), # True condition
(-1, 10, 9), # Elif condition 1
(-2, 10, 8), # Elif condition 2
(-3, 10, ()), # No condition met
],
)
def test_cond_true_elifs(testing_functions, selector, arg, expected):
"""Test the conditional with true and elifs branches."""

def elif_fn2(arg):
return arg - 2
true_fn, _, elif_fn1, elif_fn2, _, _ = testing_functions

return qml.cond(
result = qml.cond(
selector > 0,
true_fn,
elifs=(
(selector == -1, elif_fn1),
(selector == -2, elif_fn2),
),
)(arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"


def cond_true_false(selector, arg):
"""A function with conditional containing true and false branches."""
@pytest.mark.parametrize(
"selector, arg, expected",
[
(1, 10, 20), # True condition
(0, 10, 30), # False condition
],
)
def test_cond_true_false(testing_functions, selector, arg, expected):
"""Test the conditional with true and false branches."""

def true_fn(arg):
return 2 * arg
true_fn, false_fn, _, _, _, _ = testing_functions

def false_fn(arg):
return 3 * arg

return qml.cond(
result = qml.cond(
selector > 0,
true_fn,
false_fn,
)(arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"


def cond_true(selector, arg):
"""A function with conditional containing only the true branch."""
@pytest.mark.parametrize(
"selector, arg, expected",
[
(1, 10, 20), # True condition
(0, 10, ()), # No condition met
],
)
def test_cond_true(testing_functions, selector, arg, expected):
"""Test the conditional with only the true branch."""

def true_fn(arg):
return 2 * arg
true_fn, _, _, _, _, _ = testing_functions

return qml.cond(
result = qml.cond(
selector > 0,
true_fn,
)(arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"


# pylint: disable=no-self-use
class TestCond:
"""Tests for capturing conditional statements."""

@pytest.mark.parametrize(
"selector, arg, expected",
[
(1, 10, 20), # True condition
(-1, 10, 9), # Elif condition 1
(-2, 10, 8), # Elif condition 2
(-3, 10, 7), # Elif condition 3
(-4, 10, 6), # Elif condition 4
(0, 10, 30), # False condition
],
)
def test_cond_true_elifs_false(self, selector, arg, expected):
"""Test the conditional with true, elifs, and false branches."""

result = cond_true_elifs_false(selector, arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

@pytest.mark.parametrize(
"selector, arg, expected",
[
(1, 10, 20), # True condition
(-1, 10, 9), # Elif condition 1
(-2, 10, 8), # Elif condition 2
(-3, 10, ()), # No condition met
],
)
def test_cond_true_elifs(self, selector, arg, expected):
"""Test the conditional with true and elifs branches."""

result = cond_true_elifs(selector, arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

@pytest.mark.parametrize(
"selector, arg, expected",
[
(1, 10, 20), # True condition
(0, 10, 30), # False condition
],
)
def test_cond_true_false(self, selector, arg, expected):
"""Test the conditional with true and false branches."""

result = cond_true_false(selector, arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

@pytest.mark.parametrize(
"selector, arg, expected",
[
(1, 10, 20), # True condition
(0, 10, ()), # No condition met
],
)
def test_cond_true(self, selector, arg, expected):
"""Test the conditional with only the true branch."""

result = cond_true(selector, arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"
def test_validate_number_of_output_variables():
"""Test mismatch in number of output variables."""

def true_fn(x):
return x + 1, x + 2

def false_fn(x):
return x + 1

with pytest.raises(AssertionError, match=r"Mismatch in number of output variables"):
jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jnp.array(1))


def test_validate_output_variable_types():
"""Test mismatch in output variable types."""

def true_fn(x):
return x + 1, x + 2

def false_fn(x):
return x + 1, x + 2.0

with pytest.raises(AssertionError, match=r"Mismatch in output abstract values"):
jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jnp.array(1))


def test_validate_elif_branches():
"""Test elif branch mismatches."""

def true_fn(x):
return x + 1, x + 2

def false_fn(x):
return x + 1, x + 2

def elif_fn1(x):
return x + 1, x + 2

def elif_fn2(x):
return x + 1, x + 2.0 # Type mismatch

def elif_fn3(x):
return x + 1 # Length mismatch

with pytest.raises(
AssertionError, match=r"Mismatch in output abstract values in elif branch #1"
):
jax.make_jaxpr(
_capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)])
)(jnp.array(1))

with pytest.raises(
AssertionError, match=r"Mismatch in number of output variables in elif branch #0"
):
jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))(jnp.array(1))


@pytest.mark.parametrize(
"true_fn, false_fn, expected_error, match",
[
(
lambda x: (x + 1, x + 2),
lambda x: (x + 1),
AssertionError,
r"Mismatch in number of output variables",
),
(
lambda x: (x + 1, x + 2),
lambda x: (x + 1, x + 2.0),
AssertionError,
r"Mismatch in output abstract values",
),
],
)
def test_validate_mismatches(true_fn, false_fn, expected_error, match):
"""Test mismatch in number and type of output variables."""
with pytest.raises(expected_error, match=match):
jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jnp.array(1))

0 comments on commit e8fef5d

Please sign in to comment.