diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 5a04c3946f0..9cc06aaeadf 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -142,6 +142,29 @@ 0.11917543, 0.08942104, 0.21545687], dtype=float64) ``` +* If the conditional does not include a mid-circuit measurement, then `qml.cond` + will automatically evaluate conditionals using standard Python control flow. + [(#6016)](https://github.com/PennyLaneAI/pennylane/pull/6016) + + This allows `qml.cond` to be used to represent a wider range of conditionals: + + ```python + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev) + def circuit(x): + c = qml.cond(x > 2.7, qml.RX, qml.RZ) + c(x, wires=0) + return qml.probs(wires=0) + ``` + + ```pycon + >>> print(qml.draw(circuit)(3.8)) + 0: ──RX(3.80)─┤ Probs + >>> print(qml.draw(circuit)(0.54)) + 0: ──RZ(0.54)─┤ Probs + ``` + * The `qubit_observable` function is modified to return an ascending wire order for molecular Hamiltonians. [(#5950)](https://github.com/PennyLaneAI/pennylane/pull/5950) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 7df1bcb54f6..60380d3ad01 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -21,6 +21,7 @@ import pennylane as qml from pennylane import QueuingManager from pennylane.compiler import compiler +from pennylane.measurements import MeasurementValue from pennylane.operation import AnyWires, Operation, Operator from pennylane.ops.op_math.symbolicop import SymbolicOp from pennylane.tape import make_qscript @@ -102,7 +103,174 @@ def adjoint(self): return Conditional(self.meas_val, self.base.adjoint()) -def cond(condition, true_fn: Callable, false_fn: Optional[Callable] = None, elifs=()): +class CondCallable: # pylint:disable=too-few-public-methods + """Base class to represent a conditional function with boolean predicates. + + Args: + condition (bool): a conditional expression + true_fn (callable): The function to apply if ``condition`` is ``True`` + false_fn (callable): The function to apply if ``condition`` is ``False`` + elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses. + + Passing ``false_fn`` and ``elifs`` on initialization + is optional; these functions can be registered post-initialization + via decorators: + + .. code-block:: python + + def f(x): + @qml.cond(x > 0) + def conditional(y): + return y ** 2 + + @conditional.else_if(x < -2) + def conditional(y): + return y + + @conditional.otherwise + def conditional_false_fn(y): + return -y + + return conditional(x + 1) + + >>> [f(0.5), f(-3), f(-0.5)] + [2.25, -2, -0.5] + """ + + def __init__(self, condition, true_fn, false_fn=None, elifs=()): + self.preds = [condition] + self.branch_fns = [true_fn] + self.otherwise_fn = false_fn + + # when working with `qml.capture.enabled()`, + # it's easier to store the original `elifs` argument + self.orig_elifs = elifs + + if false_fn is None and not qml.capture.enabled(): + self.otherwise_fn = lambda *args, **kwargs: None + + if elifs and not qml.capture.enabled(): + elif_preds, elif_fns = list(zip(*elifs)) + self.preds.extend(elif_preds) + self.branch_fns.extend(elif_fns) + + def else_if(self, pred): + """Decorator that allows else-if functions to be registered with a corresponding + boolean predicate. + + Args: + pred (bool): The predicate that will determine if this branch is executed. + + Returns: + callable: decorator that is applied to the else-if function + """ + + def decorator(branch_fn): + self.preds.append(pred) + self.branch_fns.append(branch_fn) + return self + + return decorator + + def otherwise(self, otherwise_fn): + """Decorator that registers the function to be run if all + conditional predicates (including optional) evaluates to ``False``. + + Args: + otherwise_fn (callable): the function to apply if all ``self.preds`` evaluate to ``False`` + """ + self.otherwise_fn = otherwise_fn + return self + + @property + def false_fn(self): + """callable: the function to apply if all ``self.preds`` evaluate to ``False``""" + return self.otherwise_fn + + @property + def true_fn(self): + """callable: the function to apply if all ``self.condition`` evaluate to ``True``""" + return self.branch_fns[0] + + @property + def condition(self): + """bool: the condition that determines if ``self.true_fn`` is applied""" + return self.preds[0] + + @property + def elifs(self): + """(List(Tuple(bool, callable))): a list of (bool, elif_fn) clauses""" + return list(zip(self.preds[1:], self.branch_fns[1:])) + + def __call_capture_disabled(self, *args, **kwargs): + # python fallback + for pred, branch_fn in zip(self.preds, self.branch_fns): + if pred: + return branch_fn(*args, **kwargs) + + return self.false_fn(*args, **kwargs) # pylint: disable=not-callable + + def __call_capture_enabled(self, *args, **kwargs): + + import jax # pylint: disable=import-outside-toplevel + + cond_prim = _get_cond_qfunc_prim() + + elifs = ( + (self.orig_elifs,) + if len(self.orig_elifs) > 0 and not isinstance(self.orig_elifs[0], tuple) + else self.orig_elifs + ) + + @wraps(self.true_fn) + def new_wrapper(*args, **kwargs): + + jaxpr_true = jax.make_jaxpr(functools.partial(self.true_fn, **kwargs))(*args) + jaxpr_false = ( + jax.make_jaxpr(functools.partial(self.otherwise_fn, **kwargs))(*args) + if self.otherwise_fn + else None + ) + + # We extract each condition (or predicate) from the elifs argument list + # since these are traced by JAX and are passed as positional arguments to the primitive + elifs_conditions = [] + jaxpr_elifs = [] + + for pred, elif_fn in elifs: + elifs_conditions.append(pred) + jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) + + conditions = jax.numpy.array([self.condition, *elifs_conditions, True]) + + jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false] + jaxpr_consts = [jaxpr.consts if jaxpr is not None else () for jaxpr in jaxpr_branches] + + # We need to flatten the constants since JAX does not allow + # to pass lists as positional arguments + consts_flat = [const for sublist in jaxpr_consts for const in sublist] + n_consts_per_branch = [len(consts) for consts in jaxpr_consts] + + return cond_prim.bind( + conditions, + *args, + *consts_flat, + jaxpr_branches=jaxpr_branches, + n_consts_per_branch=n_consts_per_branch, + n_args=len(args), + ) + + return new_wrapper(*args, **kwargs) + + def __call__(self, *args, **kwargs): + + if qml.capture.enabled(): + return self.__call_capture_enabled(*args, **kwargs) + + return self.__call_capture_disabled(*args, **kwargs) + + +def cond(condition, true_fn: Callable = None, false_fn: Optional[Callable] = None, elifs=()): """Quantum-compatible if-else conditionals --- condition quantum operations on parameters such as the results of mid-circuit qubit measurements. @@ -139,15 +307,15 @@ def cond(condition, true_fn: Callable, false_fn: Optional[Callable] = None, elif If a branch returns one or more variables, every other branch must return the same abstract values. Args: - condition (Union[.MeasurementValue, bool]): a conditional expression involving a mid-circuit - measurement value (see :func:`.pennylane.measure`). This can only be of type ``bool`` when - decorated by :func:`~.qjit`. + condition (Union[.MeasurementValue, bool]): a conditional expression that may involve a mid-circuit + measurement value (see :func:`.pennylane.measure`). true_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``True`` false_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``False`` elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses. Can only - be used when decorated by :func:`~.qjit`. + be used when decorated by :func:`~.qjit` or if the condition is not + a mid-circuit measurement. Returns: function: A new function that applies the conditional equivalent of ``true_fn``. The returned @@ -380,6 +548,10 @@ def qnode(a, x, y, z): if active_jit := compiler.active_compiler(): available_eps = compiler.AvailableCompilers.names_entrypoints ops_loader = available_eps[active_jit]["ops"].load() + + if true_fn is None: + return ops_loader.cond(condition) + cond_func = ops_loader.cond(condition)(true_fn) # Optional 'elif' branches @@ -392,11 +564,26 @@ def qnode(a, x, y, z): return cond_func - if qml.capture.enabled(): - return _capture_cond(condition, true_fn, false_fn, elifs) + if not isinstance(condition, MeasurementValue): + # The condition is not a mid-circuit measurement. + if true_fn is None: + return lambda fn: CondCallable(condition, fn) + + return CondCallable(condition, true_fn, false_fn, elifs) + + if true_fn is None: + raise TypeError( + "cond missing 1 required positional argument: 'true_fn'.\n" + "Note that if the conditional includes a mid-circuit measurement, " + "qml.cond cannot be used as a decorator.\n" + "Instead, please use the form qml.cond(condition, true_fn, false_fn)." + ) if elifs: - raise ConditionalTransformError("'elif' branches are not supported in interpreted mode.") + raise ConditionalTransformError( + "'elif' branches are not supported when not using @qjit and the " + "conditional include mid-circuit measurements." + ) if callable(true_fn): # We assume that the callable is an operation or a quantum function @@ -519,51 +706,3 @@ def _(*_, jaxpr_branches, **__): return outvals_true return cond_prim - - -def _capture_cond(condition, true_fn, false_fn=None, elifs=()) -> Callable: - """Capture compatible way to apply conditionals.""" - - import jax # pylint: disable=import-outside-toplevel - - cond_prim = _get_cond_qfunc_prim() - - elifs = (elifs,) if len(elifs) > 0 and not isinstance(elifs[0], tuple) else elifs - - @wraps(true_fn) - def new_wrapper(*args, **kwargs): - - jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) - jaxpr_false = ( - jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None - ) - - # We extract each condition (or predicate) from the elifs argument list - # since these are traced by JAX and are passed as positional arguments to the primitive - elifs_conditions = [] - jaxpr_elifs = [] - - for pred, elif_fn in elifs: - elifs_conditions.append(pred) - jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) - - conditions = jax.numpy.array([condition, *elifs_conditions, True]) - - jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false] - jaxpr_consts = [jaxpr.consts if jaxpr is not None else () for jaxpr in jaxpr_branches] - - # We need to flatten the constants since JAX does not allow - # to pass lists as positional arguments - consts_flat = [const for sublist in jaxpr_consts for const in sublist] - n_consts_per_branch = [len(consts) for consts in jaxpr_consts] - - return cond_prim.bind( - conditions, - *args, - *consts_flat, - jaxpr_branches=jaxpr_branches, - n_consts_per_branch=n_consts_per_branch, - n_args=len(args), - ) - - return new_wrapper diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index a6ddfa63dce..a3a70486157 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -22,7 +22,7 @@ import pytest import pennylane as qml -from pennylane.ops.op_math.condition import _capture_cond +from pennylane.ops.op_math.condition import CondCallable pytestmark = pytest.mark.jax @@ -229,7 +229,7 @@ class TestCondReturns: def test_validate_mismatches(self, true_fn, false_fn, expected_error, match): """Test mismatch in number and type of output variables.""" with pytest.raises(expected_error, match=match): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + jax.make_jaxpr(CondCallable(True, true_fn, false_fn))(jax.numpy.array(1)) def test_validate_number_of_output_variables(self): """Test mismatch in number of output variables.""" @@ -241,7 +241,7 @@ def false_fn(x): return x + 1 with pytest.raises(ValueError, match=r"Mismatch in number of output variables"): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + jax.make_jaxpr(CondCallable(True, true_fn, false_fn))(jax.numpy.array(1)) def test_validate_output_variable_types(self): """Test mismatch in output variable types.""" @@ -253,7 +253,7 @@ def false_fn(x): return x + 1, x + 2.0 with pytest.raises(ValueError, match=r"Mismatch in output abstract values"): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + jax.make_jaxpr(CondCallable(True, true_fn, false_fn))(jax.numpy.array(1)) def test_validate_no_false_branch_with_return(self): """Test no false branch provided with return variables.""" @@ -265,7 +265,7 @@ def true_fn(x): ValueError, match=r"The false branch must be provided if the true branch returns any variables", ): - jax.make_jaxpr(_capture_cond(True, true_fn))(jax.numpy.array(1)) + jax.make_jaxpr(CondCallable(True, true_fn))(jax.numpy.array(1)) def test_validate_no_false_branch_with_return_2(self): """Test no false branch provided with return variables.""" @@ -280,9 +280,7 @@ def elif_fn(x): ValueError, match=r"The false branch must be provided if the true branch returns any variables", ): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn=None, elifs=(False, elif_fn)))( - jax.numpy.array(1) - ) + jax.make_jaxpr(CondCallable(True, true_fn, elifs=[(True, elif_fn)]))(jax.numpy.array(1)) def test_validate_elif_branches(self): """Test elif branch mismatches.""" @@ -306,13 +304,13 @@ def elif_fn3(x): ValueError, match=r"Mismatch in output abstract values in elif branch #1" ): jax.make_jaxpr( - _capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) + CondCallable(True, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) )(jax.numpy.array(1)) with pytest.raises( ValueError, match=r"Mismatch in number of output variables in elif branch #0" ): - jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))( + jax.make_jaxpr(CondCallable(True, true_fn, false_fn, elifs=[(True, elif_fn3)]))( jax.numpy.array(1) ) @@ -328,7 +326,7 @@ def true_fn(): qml.RX(0.1, wires=0) qml.cond(pred > 0, true_fn)() - return qml.expval(qml.PauliZ(wires=0)) + return qml.expval(qml.Z(wires=0)) @qml.qnode(dev) @@ -352,7 +350,7 @@ def elif_fn1(arg1, arg2): qml.cond(pred > 0, true_fn, false_fn, elifs=(pred == -1, elif_fn1))(arg1, arg2) qml.RX(0.10, wires=0) - return qml.expval(qml.PauliZ(wires=0)) + return qml.expval(qml.Z(wires=0)) @qml.qnode(dev) @@ -371,7 +369,7 @@ def false_fn(arg1, arg2): qml.cond(pred > 0, true_fn, false_fn)(arg1, arg2) qml.RX(0.10, wires=0) - return qml.expval(qml.PauliZ(wires=0)) + return qml.expval(qml.Z(wires=0)) @qml.qnode(dev) diff --git a/tests/devices/test_device.py b/tests/devices/test_device.py index a5f6631eb90..b84723d634a 100644 --- a/tests/devices/test_device.py +++ b/tests/devices/test_device.py @@ -581,21 +581,6 @@ def test_mcm_unsupported_error(self, mock_device_with_paulis_and_methods): with pytest.raises(DeviceError, match="Mid-circuit measurements are not natively"): dev.check_validity(tape.operations, tape.observables) - def test_conditional_ops_unsupported_error(self, mock_device_with_paulis_and_methods): - """Test that an error is raised for conditional operations if - mid-circuit measurements are not supported natively""" - dev = mock_device_with_paulis_and_methods(wires=2) - - with qml.queuing.AnnotatedQueue() as q: - qml.cond(0, qml.RY)(0.3, wires=0) - qml.PauliZ(0) - - tape = qml.tape.QuantumScript.from_queue(q) - # Raises an error for device that doesn't support conditional - # operations natively - with pytest.raises(DeviceError, match="Gate Conditional\\(RY\\) not supported on device"): - dev.check_validity(tape.operations, tape.observables) - @pytest.mark.parametrize( "wires, subset, expected_subset", [ diff --git a/tests/ops/op_math/test_condition.py b/tests/ops/op_math/test_condition.py index c8fb8ab925e..2db00e49fc5 100644 --- a/tests/ops/op_math/test_condition.py +++ b/tests/ops/op_math/test_condition.py @@ -495,3 +495,177 @@ def test_adjoint(self): assert isinstance(adj_op, Conditional) assert adj_op.meas_val == op.meas_val assert adj_op.base == base.adjoint() + + +class TestPythonFallback: + """Test python fallback""" + + def test_simple_if(self): + """Test a simple if statement""" + + def f(x): + c = qml.cond(x > 1, np.sin) + assert c.true_fn is np.sin + assert c.condition is (x > 1) + return c(x) + + assert np.allclose(f(1.5), np.sin(1.5)) + assert f(0.5) is None + + def test_simple_if_else(self): + """Test a simple if-else statement""" + + def f(x): + c = qml.cond(x > 1, np.sin, np.cos) + assert c.false_fn is np.cos + return c(x) + + assert np.allclose(f(1.5), np.sin(1.5)) + assert np.allclose(f(0.5), np.cos(0.5)) + + def test_simple_if_elif_else(self): + """Test a simple if-elif-else statement""" + + def f(x): + elifs = [(x >= -1, lambda y: y**2), (x > -10, lambda y: y**3)] + c = qml.cond(x > 1, np.sin, np.cos, elifs) + return c(x) + + assert np.allclose(f(1.5), np.sin(1.5)) + assert np.allclose(f(-0.5), (-0.5) ** 2) + assert np.allclose(f(-5), (-5) ** 3) + assert np.allclose(f(-10.5), np.cos(-10.5)) + + def test_simple_if_elif_else_order(self): + """Test a simple if-elif-else statement where the order of the elif + statements matter""" + + def f(x): + elifs = [(x > -10, lambda y: y**3), (x >= -1, lambda y: y**2)] + c = qml.cond(x > 1, np.sin, np.cos, elifs) + + for i, j in zip(c.elifs, elifs): + assert i[0] is j[0] + assert i[1] is j[1] + + return c(x) + + assert np.allclose(f(1.5), np.sin(1.5)) + assert np.allclose(f(-0.5), (-0.5) ** 3) + assert np.allclose(f(-5), (-5) ** 3) + assert np.allclose(f(-10.5), np.cos(-10.5)) + + def test_decorator_syntax_if(self): + """test a decorator if statement""" + + def f(x): + @qml.cond(x > 0) + def conditional(y): + return y**2 + + return conditional(x + 1) + + assert np.allclose(f(0.5), (0.5 + 1) ** 2) + assert f(-0.5) is None + + def test_decorator_syntax_if_else(self): + """test a decorator if-else statement""" + + def f(x): + @qml.cond(x > 0) + def conditional(y): + return y**2 + + @conditional.otherwise + def conditional_false_fn(y): # pylint: disable=unused-variable + return -y + + return conditional(x + 1) + + assert np.allclose(f(0.5), (0.5 + 1) ** 2) + assert np.allclose(f(-0.5), -(-0.5 + 1)) + + def test_decorator_syntax_if_elif_else(self): + """test a decorator if-elif-else statement""" + + def f(x): + @qml.cond(x > 0) + def conditional(y): + return y**2 + + @conditional.else_if(x < -2) + def conditional_elif(y): # pylint: disable=unused-variable + return y + + @conditional.otherwise + def conditional_false_fn(y): # pylint: disable=unused-variable + return -y + + return conditional(x + 1) + + assert np.allclose(f(0.5), (0.5 + 1) ** 2) + assert np.allclose(f(-0.5), -(-0.5 + 1)) + assert np.allclose(f(-2.5), (-2.5 + 1)) + + def test_error_mcms_elif(self): + """Test that an error is raised if elifs are provided + when the conditional includes an MCM""" + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + m = qml.measure(0) + qml.cond(m, qml.RX, elifs=[(~m, qml.RY)]) + return qml.probs + + with pytest.raises(ConditionalTransformError, match="'elif' branches are not supported"): + circuit(0.5) + + def test_error_no_true_fn(self): + """Test that an error is raised if no true_fn is provided + when the conditional includes an MCM""" + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + m = qml.measure(0) + + @qml.cond(m) + def conditional(): + qml.RZ(x**2) + + conditional() + return qml.probs + + with pytest.raises(TypeError, match="cannot be used as a decorator"): + circuit(0.5) + + def test_qnode(self): + """Test that qml.cond fallsback to Python when used + within a QNode""" + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev) + def circuit(x): + elifs = [(x > 1.4, lambda y, wires: qml.RY(y**2, wires=wires))] + c = qml.cond(x > 2.7, qml.RX, qml.RZ, elifs) + c(x, wires=0) + return qml.probs(wires=0) + + circuit(3) + ops = circuit.tape.operations + assert len(ops) == 1 + assert ops[0].name == "RX" + + circuit(2) + ops = circuit.tape.operations + assert len(ops) == 1 + assert ops[0].name == "RY" + assert np.allclose(ops[0].parameters[0], 2**2) + + circuit(1) + ops = circuit.tape.operations + assert len(ops) == 1 + assert ops[0].name == "RZ" diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 92130d5f3fd..5c02608286d 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -599,11 +599,31 @@ def elif_fn(): return qml.expval(qml.PauliZ(0)) - with pytest.raises( - ValueError, - match="'elif' branches are not supported in interpreted mode", - ): - circuit(1.5) + assert jnp.allclose(circuit(1.2), 1.0) + assert jnp.allclose(circuit(jnp.pi), -1.0) + + def test_cond_with_decorator_syntax(self): + """Test condition using the decorator syntax""" + + @qml.qjit + def f(x): + @qml.cond(x > 0) + def conditional(): + return (x + 1) ** 2 + + @conditional.else_if(x < -2) + def conditional_elif(): # pylint: disable=unused-variable + return x + 1 + + @conditional.otherwise + def conditional_false_fn(): # pylint: disable=unused-variable + return -(x + 1) + + return conditional() + + assert np.allclose(f(0.5), (0.5 + 1) ** 2) + assert np.allclose(f(-0.5), -(-0.5 + 1)) + assert np.allclose(f(-2.5), (-2.5 + 1)) class TestCatalystGrad: