From 08c588642cd1d04634b7136f68473aa5e76b4f21 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Wed, 17 Jul 2024 21:54:58 -0400 Subject: [PATCH 01/20] Allow `qml.for_loop` and `qml.while_loop` to fallback to the python interpreter if a compiler is not available. --- pennylane/compiler/qjit_api.py | 75 +++++++++++++++++++++++++++++++--- 1 file changed, 69 insertions(+), 6 deletions(-) diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 140bca19b4b..e508f2b00d3 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -377,7 +377,37 @@ def loop_rx(x): ops_loader = compilers[active_jit]["ops"].load() return ops_loader.while_loop(cond_fn) - raise CompileError("There is no active compiler package.") # pragma: no cover + # if there is no active compiler, simply interpret the while loop + # via the Python interpretor. + def _decorator(body_fn): + return WhileLoopCallable(cond_fn, body_fn) + + return _decorator + + +class WhileLoopCallable: + """Base class to represent a while loop. This class + when called with an initial state will execute the while + loop via the Python interpreter. + + Args: + cond_fn (Callable): the condition function in the while loop + body_fn (Callable): the function that is executed within the while loop + """ + + def __init__(self, cond_fn, body_fn): + self.cond_fn = cond_fn + self.body_fn = body_fn + + def __call__(self, *init_state): + args = init_state + fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None + + while self.cond_fn(*args): + fn_res = self.body_fn(*args) + args = fn_res if len(args) > 1 else (fn_res,) if len(args) == 1 else () + + return fn_res def for_loop(lower_bound, upper_bound, step): @@ -430,14 +460,10 @@ def for_loop(lower_bound, upper_bound, step, loop_fn, *args): across iterations is handled automatically by the provided loop bounds, it must not be returned from the function. - Raises: - CompileError: if the compiler is not installed - .. seealso:: :func:`~.while_loop`, :func:`~.qjit` **Example** - .. code-block:: python dev = qml.device("lightning.qubit", wires=1) @@ -468,4 +494,41 @@ def loop_rx(i, x): ops_loader = compilers[active_jit]["ops"].load() return ops_loader.for_loop(lower_bound, upper_bound, step) - raise CompileError("There is no active compiler package.") # pragma: no cover + # if there is no active compiler, simply interpret the for loop + # via the Python interpretor. + def _decorator(body_fn): + return ForLoopCallable(lower_bound, upper_bound, step, body_fn) + + return _decorator + +class ForLoopCallable: + """Base class to represent a for loop. This class + when called with an initial state will execute the while + loop via the Python interpreter. + + Args: + lower_bound (int): starting value of the iteration index + upper_bound (int): (exclusive) upper bound of the iteration index + step (int): increment applied to the iteration index at the end of each iteration + body_fn (Callable): The function called within the for loop. Note that the loop body + function must always have the iteration index as its first + argument, which can be used arbitrarily inside the loop body. As the value of the index + across iterations is handled automatically by the provided loop bounds, it must not be + returned from the function. + """ + + def __init__(self, lower_bound, upper_bound, step, body_fn): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + self.step = step + self.body_fn = body_fn + + def __call__(self, *init_state): + args = init_state + fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None + + for i in range(self.lower_bound, self.upper_bound, self.step): + fn_res = self.body_fn(i, *args) + args = fn_res if len(args) > 1 else (fn_res,) if len(args) == 1 else () + + return fn_res From d48ba18d2c5facaa8396156bffeca07026741403 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sat, 20 Jul 2024 21:56:21 -0400 Subject: [PATCH 02/20] add tests --- pennylane/compiler/qjit_api.py | 7 ++-- tests/test_compiler.py | 71 ++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index e508f2b00d3..9cefcb11445 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -385,7 +385,7 @@ def _decorator(body_fn): return _decorator -class WhileLoopCallable: +class WhileLoopCallable: # pylint:disable=too-few-public-methods """Base class to represent a while loop. This class when called with an initial state will execute the while loop via the Python interpreter. @@ -501,7 +501,8 @@ def _decorator(body_fn): return _decorator -class ForLoopCallable: + +class ForLoopCallable: # pylint:disable=too-few-public-methods """Base class to represent a for loop. This class when called with an initial state will execute the while loop via the Python interpreter. @@ -526,7 +527,7 @@ def __init__(self, lower_bound, upper_bound, step, body_fn): def __call__(self, *init_state): args = init_state fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None - + for i in range(self.lower_bound, self.upper_bound, self.step): fn_res = self.body_fn(i, *args) args = fn_res if len(args) > 1 else (fn_res,) if len(args) == 1 else () diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 4b66bf2d061..a562efa6fda 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -353,6 +353,24 @@ def inner(j): assert circuit(5, 6) == 30 # 5 * 6 assert circuit(4, 7) == 28 # 4 * 7 + def test_while_loop_python_fallback(self): + """Test that qml.while_loop fallsback to + Python without qjit""" + + def f(n, m): + @qml.while_loop(lambda i, _: i < n) + def outer(i, sm): + @qml.while_loop(lambda j: j < m) + def inner(j): + return j + 1 + + return i + 1, sm + inner(0) + + return outer(0, 0)[1] + + assert f(5, 6) == 30 # 5 * 6 + assert f(4, 7) == 28 # 4 * 7 + def test_dynamic_wires_for_loops(self): """Test for loops with iteration index-dependant wires.""" dev = qml.device("lightning.qubit", wires=6) @@ -405,6 +423,59 @@ def inner(j): assert jnp.allclose(circuit(4), jnp.eye(2**4)[0]) + def test_for_loop_python_fallback(self, mocker): + """Test that qml.for_loop fallsback to Python + interpretation if Catalyst is not available""" + mocker.patch('pennylane.compiler.available', return_value=False) + + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(dev) + def circuit(x, n): + + # for loop with dynamic bounds + @qml.for_loop(0, n, 1) + def loop_fn(i): + qml.Hadamard(wires=i) + + # nested for loops. + # outer for loop updates x + @qml.for_loop(0, n, 1) + def loop_fn_returns(i, x): + qml.RX(x, wires=i) + + # inner for loop + @qml.for_loop(i + 1, n, 1) + def inner(j): + qml.CRY(x ** 2, [i, j]) + + inner() + + return jnp.sin(x) + + loop_fn() + loop_fn_returns(x) + + return qml.expval(qml.PauliZ(0)) + + x = 0.5 + assert jnp.allclose(circuit(x, 2), qml.qjit(circuit)(x, 2)) + + res = circuit.tape.operations + expected = [ + qml.Hadamard(wires=[0]), + qml.Hadamard(wires=[1]), + qml.Hadamard(wires=[2]), + qml.RX(0.5, wires=[0]), + qml.CRY(0.25, wires=[0, 1]), + qml.CRY(0.25, wires=[0, 2]), + qml.RX(0.6, wires=[1]), + qml.CRY(0.36, wires=[1, 2]), + qml.RX(0.7, wires=[2]) + ] + + assert [qml.equal(i, j) for i, j in zip(res, expected)] + def test_cond(self): """Test condition with simple true_fn""" dev = qml.device("lightning.qubit", wires=1) From ea87c28052d4f28243aa98c815f995b00589aa3a Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sat, 20 Jul 2024 22:12:25 -0400 Subject: [PATCH 03/20] black test compiler --- tests/test_compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index a562efa6fda..394acc87ccf 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -426,7 +426,7 @@ def inner(j): def test_for_loop_python_fallback(self, mocker): """Test that qml.for_loop fallsback to Python interpretation if Catalyst is not available""" - mocker.patch('pennylane.compiler.available', return_value=False) + mocker.patch("pennylane.compiler.available", return_value=False) dev = qml.device("lightning.qubit", wires=2) @@ -447,7 +447,7 @@ def loop_fn_returns(i, x): # inner for loop @qml.for_loop(i + 1, n, 1) def inner(j): - qml.CRY(x ** 2, [i, j]) + qml.CRY(x**2, [i, j]) inner() @@ -471,7 +471,7 @@ def inner(j): qml.CRY(0.25, wires=[0, 2]), qml.RX(0.6, wires=[1]), qml.CRY(0.36, wires=[1, 2]), - qml.RX(0.7, wires=[2]) + qml.RX(0.7, wires=[2]), ] assert [qml.equal(i, j) for i, j in zip(res, expected)] From 51ec3c0a54c7fa8a8b383ad5183d95084de15da6 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sat, 20 Jul 2024 23:33:26 -0400 Subject: [PATCH 04/20] Allow `qml.cond` to fallback to the Python interpreter if a compiler is not available and there are no MCMs --- pennylane/ops/op_math/condition.py | 127 ++++++++++++++++++++++++- tests/ops/op_math/test_condition.py | 139 ++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+), 4 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index d3da2814719..d65df8963ac 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -20,6 +20,7 @@ from pennylane import QueuingManager from pennylane.compiler import compiler from pennylane.operation import AnyWires, Operation, Operator +from pennylane.measurements import MeasurementValue from pennylane.ops.op_math.symbolicop import SymbolicOp from pennylane.tape import make_qscript @@ -100,7 +101,111 @@ def adjoint(self): return Conditional(self.meas_val, self.base.adjoint()) -def cond(condition, true_fn, false_fn=None, elifs=()): +class CondCallable: # pylint:disable=too-few-public-methods + """Base class to represent a conditional function with boolean predicates. + + Args: + condition (bool): a conditional expression + true_fn (callable): The function to apply if ``condition`` is ``True`` + false_fn (callable): The function to apply if ``condition`` is ``False`` + elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses. + + Passing ``true_fn``, ``false_fn``, and ``elifs`` on initialization + is optional; these functions can be registered post-initialization + via decorators: + + .. code-block:: python + + def f(x): + @qml.cond(x > 0) + def conditional(y): + return y ** 2 + + @conditional.else_if(x < -2) + def conditional(y): + return y + + @conditional.otherwise + def conditional_false_fn(y): + return -y + + return conditional(x + 1) + + >>> [f(0.5), f(-3), f(-0.5)] + [2.25, -2, -0.5] + """ + + def __init__(self, condition, true_fn, false_fn=None, elifs=()): + self.preds = [condition] + self.branch_fns = [true_fn] + self.otherwise_fn = false_fn + + if false_fn is None: + self.otherwise_fn = lambda *args, **kwargs: None + + if elifs: + elif_preds, elif_fns = list(zip(*elifs)) + self.preds.extend(elif_preds) + self.branch_fns.extend(elif_fns) + + def else_if(self, pred): + """Decorator that allows else-if functions to be registered with a corresponding + boolean predicate. + + Args: + pred (bool): The predicate that will determine if this branch is executed. + + Returns: + callable: decorator that is applied to the else-if function + """ + + def decorator(branch_fn): + self.preds.append(pred) + self.branch_fns.append(branch_fn) + return self + + return decorator + + def otherwise(self, otherwise_fn): + """Decorator that registers the function to be run if all + conditional predicates (including optional) evaluates to ``False``. + + Args: + otherwise_fn (callable): the function to apply if all ``self.preds`` evaluate to ``False`` + """ + self.otherwise_fn = otherwise_fn + return self + + @property + def false_fn(self): + """callable: the function to apply if all ``self.preds`` evaluate to ``False``""" + return self.otherwise_fn + + @property + def true_fn(self): + """callable: the function to apply if all ``self.condition`` evaluate to ``True``""" + return self.branch_fns[0] + + @property + def condition(self): + """bool: the condition that determines if ``self.true_fn`` is applied""" + return self.preds[0] + + @property + def elifs(self): + """(List(Tuple(bool, callable))): a list of (bool, elif_fn) clauses""" + return list(zip(self.preds[1:], self.branch_fns[1:])) + + def __call__(self, *args, **kwargs): + # python fallback + for pred, branch_fn in zip(self.preds, self.branch_fns): + if pred: + return branch_fn(*args, **kwargs) + + return self.false_fn(*args, **kwargs) + + +def cond(condition, true_fn=None, false_fn=None, elifs=()): """Quantum-compatible if-else conditionals --- condition quantum operations on parameters such as the results of mid-circuit qubit measurements. @@ -128,14 +233,14 @@ def cond(condition, true_fn, false_fn=None, elifs=()): Args: condition (Union[.MeasurementValue, bool]): a conditional expression involving a mid-circuit - measurement value (see :func:`.pennylane.measure`). This can only be of type ``bool`` when - decorated by :func:`~.qjit`. + measurement value (see :func:`.pennylane.measure`). true_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``True`` false_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``False`` elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses. Can only - be used when decorated by :func:`~.qjit`. + be used when decorated by :func:`~.qjit` or if the condition is not + a mid-circuit measurement. Returns: function: A new function that applies the conditional equivalent of ``true_fn``. The returned @@ -367,6 +472,10 @@ def qnode(a, x, y, z): if active_jit := compiler.active_compiler(): available_eps = compiler.AvailableCompilers.names_entrypoints ops_loader = available_eps[active_jit]["ops"].load() + + if true_fn is None: + return lambda fn: ops_loader.cond(condition)(fn) + cond_func = ops_loader.cond(condition)(true_fn) # Optional 'elif' branches @@ -379,6 +488,16 @@ def qnode(a, x, y, z): return cond_func + if not isinstance(condition, MeasurementValue): + # The condition is not a mid-circuit measurement. + if true_fn is None: + return lambda fn: CondCallable(condition, fn) + + return CondCallable(condition, true_fn, false_fn, elifs) + + if true_fn is None: + raise TypeError("cond missing 1 required positional argument: 'true_fn'") + if elifs: raise ConditionalTransformError("'elif' branches are not supported in interpreted mode.") diff --git a/tests/ops/op_math/test_condition.py b/tests/ops/op_math/test_condition.py index c8fb8ab925e..2429416e5f3 100644 --- a/tests/ops/op_math/test_condition.py +++ b/tests/ops/op_math/test_condition.py @@ -495,3 +495,142 @@ def test_adjoint(self): assert isinstance(adj_op, Conditional) assert adj_op.meas_val == op.meas_val assert adj_op.base == base.adjoint() + + +class TestPythonFallback: + """Test python fallback""" + + def test_simple_if(self): + """Test a simple if statement""" + + def f(x): + c = qml.cond(x > 1, np.sin) + assert c.true_fn is np.sin + assert c.condition is (x > 1) + return c(x) + + assert np.allclose(f(1.5), np.sin(1.5)) + assert f(0.5) is None + + def test_simple_if_else(self): + """Test a simple if-else statement""" + + def f(x): + c = qml.cond(x > 1, np.sin, np.cos) + assert c.false_fn is np.cos + return c(x) + + assert np.allclose(f(1.5), np.sin(1.5)) + assert np.allclose(f(0.5), np.cos(0.5)) + + def test_simple_if_elif_else(self): + """Test a simple if-elif-else statement""" + + def f(x): + elifs = [(x >= -1, lambda y: y**2), (x > -10, lambda y: y**3)] + c = qml.cond(x > 1, np.sin, np.cos, elifs) + return c(x) + + assert np.allclose(f(1.5), np.sin(1.5)) + assert np.allclose(f(-0.5), (-0.5) ** 2) + assert np.allclose(f(-5), (-5) ** 3) + assert np.allclose(f(-10.5), np.cos(-10.5)) + + def test_simple_if_elif_else_order(self): + """Test a simple if-elif-else statement where the order of the elif + statements matter""" + + def f(x): + elifs = [(x > -10, lambda y: y**3), (x >= -1, lambda y: y**2)] + c = qml.cond(x > 1, np.sin, np.cos, elifs) + + for i, j in zip(c.elifs, elifs): + assert i[0] is j[0] + assert i[1] is j[1] + + return c(x) + + assert np.allclose(f(1.5), np.sin(1.5)) + assert np.allclose(f(-0.5), (-0.5) ** 3) + assert np.allclose(f(-5), (-5) ** 3) + assert np.allclose(f(-10.5), np.cos(-10.5)) + + def test_decorator_syntax_if(self): + """test a decorator if statement""" + + def f(x): + @qml.cond(x > 0) + def conditional(y): + return y**2 + + return conditional(x + 1) + + assert np.allclose(f(0.5), (0.5 + 1) ** 2) + assert f(-0.5) is None + + def test_decorator_syntax_if_else(self): + """test a decorator if-else statement""" + + def f(x): + @qml.cond(x > 0) + def conditional(y): + return y**2 + + @conditional.otherwise + def conditional_false_fn(y): + return -y + + return conditional(x + 1) + + assert np.allclose(f(0.5), (0.5 + 1) ** 2) + assert np.allclose(f(-0.5), -(-0.5 + 1)) + + def test_decorator_syntax_if_elif_else(self): + """test a decorator if-elif-else statement""" + + def f(x): + @qml.cond(x > 0) + def conditional(y): + return y**2 + + @conditional.else_if(x < -2) + def conditional(y): + return y + + @conditional.otherwise + def conditional_false_fn(y): + return -y + + return conditional(x + 1) + + assert np.allclose(f(0.5), (0.5 + 1) ** 2) + assert np.allclose(f(-0.5), -(-0.5 + 1)) + assert np.allclose(f(-2.5), (-2.5 + 1)) + + def test_qnode(self): + """Test that qml.cond fallsback to Python when used + within a QNode""" + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev) + def circuit(x): + elifs = [(x > 1.4, lambda y, wires: qml.RY(y**2, wires=wires))] + c = qml.cond(x > 2.7, qml.RX, qml.RZ, elifs) + c(x, wires=0) + return qml.probs(wires=0) + + circuit(3) + ops = circuit.tape.operations + assert len(ops) == 1 + assert ops[0].name == "RX" + + circuit(2) + ops = circuit.tape.operations + assert len(ops) == 1 + assert ops[0].name == "RY" + assert np.allclose(ops[0].parameters[0], 2**2) + + circuit(1) + ops = circuit.tape.operations + assert len(ops) == 1 + assert ops[0].name == "RZ" From b8e6f9f091f38c9d057e9636d51919b5c40858b6 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sat, 20 Jul 2024 23:48:22 -0400 Subject: [PATCH 05/20] linting and changelog --- doc/releases/changelog-dev.md | 22 ++++++++++++++++++++++ tests/ops/op_math/test_condition.py | 6 +++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8559de76b3b..e9503e01cb0 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -51,6 +51,28 @@ * Molecules and Hamiltonians can now be constructed for all the elements present in the periodic table. [(#5821)](https://github.com/PennyLaneAI/pennylane/pull/5821) +* If the conditional does not include a mid-circuit measurement, then `qml.cond` + will automatically evaluate conditionals using standard Python control flow. + + This allows `qml.cond` to be used to represent a wider range of conditionals: + + ```python + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev) + def circuit(x): + c = qml.cond(x > 2.7, qml.RX, qml.RZ) + c(x, wires=0) + return qml.probs(wires=0) + ``` + + ```pycon + >>> print(qml.draw(circuit)(3.8)) + 0: ──RX(3.80)─┤ Probs + >>> print(qml.draw(circuit)(0.54)) + 0: ──RZ(0.54)─┤ Probs + ``` +

Community contributions 🥳

* `DefaultQutritMixed` readout error has been added using parameters `readout_relaxation_probs` and diff --git a/tests/ops/op_math/test_condition.py b/tests/ops/op_math/test_condition.py index 2429416e5f3..d6c6a12994f 100644 --- a/tests/ops/op_math/test_condition.py +++ b/tests/ops/op_math/test_condition.py @@ -577,7 +577,7 @@ def conditional(y): return y**2 @conditional.otherwise - def conditional_false_fn(y): + def conditional_false_fn(y): # pylint: disable=unused-variable return -y return conditional(x + 1) @@ -594,11 +594,11 @@ def conditional(y): return y**2 @conditional.else_if(x < -2) - def conditional(y): + def conditional_elif(y): return y @conditional.otherwise - def conditional_false_fn(y): + def conditional_false_fn(y): # pylint: disable=unused-variable return -y return conditional(x + 1) From 93afc09cffd610c5f5a885ce18f6178f32287aea Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sat, 20 Jul 2024 23:55:17 -0400 Subject: [PATCH 06/20] changelog --- doc/releases/changelog-dev.md | 46 +++++++++++++++++++++++++++++++++++ tests/test_compiler.py | 4 +-- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8559de76b3b..d0bb829d7bb 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -51,6 +51,51 @@ * Molecules and Hamiltonians can now be constructed for all the elements present in the periodic table. [(#5821)](https://github.com/PennyLaneAI/pennylane/pull/5821) +* `qml.for_loop` and `qml.while_loop` now fallback to standard Python control + flow if `@qjit` is not present, allowing the same code to work with and without + `@qjit` without any rewrites. + [(#6014)](https://github.com/PennyLaneAI/pennylane/pull/6014) + + ```python + dev = qml.device("lightning.qubit", wires=3) + + @qml.qnode(dev) + def circuit(x, n): + + @qml.for_loop(0, n, 1) + def init_state(i): + qml.Hadamard(wires=i) + + init_state() + + @qml.for_loop(0, n, 1) + def apply_operations(i, x): + qml.RX(x, wires=i) + + @qml.for_loop(i + 1, n, 1) + def inner(j): + qml.CRY(x**2, [i, j]) + + inner() + return jnp.sin(x) + + apply_operations(x) + return qml.probs() + ``` + + ```pycon + >>> print(qml.draw(circuit)(0.5, 3)) + 0: ──H──RX(0.50)─╭●────────╭●──────────────────────────────────────┤ Probs + 1: ──H───────────╰RY(0.25)─│──────────RX(0.48)─╭●──────────────────┤ Probs + 2: ──H─────────────────────╰RY(0.25)───────────╰RY(0.23)──RX(0.46)─┤ Probs + >>> circuit(0.5, 3) + array([0.125 , 0.125 , 0.09949758, 0.15050242, 0.07594666, + 0.11917543, 0.08942104, 0.21545687]) + >>> qml.qjit(circuit)(0.5, 3) + Array([0.125 , 0.125 , 0.09949758, 0.15050242, 0.07594666, + 0.11917543, 0.08942104, 0.21545687], dtype=float64) + ``` +

Community contributions 🥳

* `DefaultQutritMixed` readout error has been added using parameters `readout_relaxation_probs` and @@ -151,6 +196,7 @@ Lillian M. A. Frederiksen, Pietropaolo Frisoni, Emiliano Godinez, Renke Huang, +Josh Izaac, Christina Lee, Austin Huang, Christina Lee, diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 394acc87ccf..4bcf5b4a150 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -423,11 +423,9 @@ def inner(j): assert jnp.allclose(circuit(4), jnp.eye(2**4)[0]) - def test_for_loop_python_fallback(self, mocker): + def test_for_loop_python_fallback(self): """Test that qml.for_loop fallsback to Python interpretation if Catalyst is not available""" - mocker.patch("pennylane.compiler.available", return_value=False) - dev = qml.device("lightning.qubit", wires=2) @qml.qnode(dev) From e6bbaf9fafdfbb8b65c74f0567f7e02141728e88 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sat, 20 Jul 2024 23:55:53 -0400 Subject: [PATCH 07/20] changelog --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index e9503e01cb0..44f3bc1ac77 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -53,6 +53,7 @@ * If the conditional does not include a mid-circuit measurement, then `qml.cond` will automatically evaluate conditionals using standard Python control flow. + [(#6016)](https://github.com/PennyLaneAI/pennylane/pull/6016) This allows `qml.cond` to be used to represent a wider range of conditionals: From e7fbe7118718c32d433490e84c196f8ffdf10633 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sat, 20 Jul 2024 23:57:44 -0400 Subject: [PATCH 08/20] linting --- tests/ops/op_math/test_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ops/op_math/test_condition.py b/tests/ops/op_math/test_condition.py index d6c6a12994f..e3ebe997033 100644 --- a/tests/ops/op_math/test_condition.py +++ b/tests/ops/op_math/test_condition.py @@ -594,7 +594,7 @@ def conditional(y): return y**2 @conditional.else_if(x < -2) - def conditional_elif(y): + def conditional_elif(y): # pylint: disable=unused-variable return y @conditional.otherwise From 3d43fe8daa3d43378f73460e178e30d7d614363c Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sun, 21 Jul 2024 00:13:53 -0400 Subject: [PATCH 09/20] coverage --- pennylane/ops/op_math/condition.py | 7 ++++++- tests/ops/op_math/test_condition.py | 20 +++++++++++++++++++ tests/test_compiler.py | 30 ++++++++++++++++++++++++----- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index d65df8963ac..d6de3de58bf 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -496,7 +496,12 @@ def qnode(a, x, y, z): return CondCallable(condition, true_fn, false_fn, elifs) if true_fn is None: - raise TypeError("cond missing 1 required positional argument: 'true_fn'") + raise TypeError( + "cond missing 1 required positional argument: 'true_fn'.\n" + "Note that if the conditional includes a mid-circuit measurement," + "qml.cond cannot be used as a decorator.\n" + "Instead, please use the form qml.cond(condition, true_fn, false_fn)." + ) if elifs: raise ConditionalTransformError("'elif' branches are not supported in interpreted mode.") diff --git a/tests/ops/op_math/test_condition.py b/tests/ops/op_math/test_condition.py index e3ebe997033..b8f852d4569 100644 --- a/tests/ops/op_math/test_condition.py +++ b/tests/ops/op_math/test_condition.py @@ -607,6 +607,26 @@ def conditional_false_fn(y): # pylint: disable=unused-variable assert np.allclose(f(-0.5), -(-0.5 + 1)) assert np.allclose(f(-2.5), (-2.5 + 1)) + def test_error_no_true_fn(self): + """Test that an error is raised if no true_fn is provided + when the conditional includes an MCM""" + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + m = qml.measure(0) + + @qml.cond(m) + def conditional(): + qml.RZ(x**2) + + conditional() + return qml.probs + + with pytest.raises(TypeError, match="cannot be used as a decorator"): + circuit(0.5) + def test_qnode(self): """Test that qml.cond fallsback to Python when used within a QNode""" diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 4bcf5b4a150..b7213e0d62c 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -578,11 +578,31 @@ def elif_fn(): return qml.expval(qml.PauliZ(0)) - with pytest.raises( - ValueError, - match="'elif' branches are not supported in interpreted mode", - ): - circuit(1.5) + assert jnp.allclose(circuit(1.2), 1.0) + assert jnp.allclose(circuit(jnp.pi), -1.0) + + def test_cond_with_decorator_syntax(self): + """Test condition using the decorator syntax""" + + @qml.qjit + def f(x): + @qml.cond(x > 0) + def conditional(): + return (x + 1) ** 2 + + @conditional.else_if(x < -2) + def conditional_elif(): # pylint: disable=unused-variable + return x + 1 + + @conditional.otherwise + def conditional_false_fn(): # pylint: disable=unused-variable + return -(x + 1) + + return conditional() + + assert np.allclose(f(0.5), (0.5 + 1) ** 2) + assert np.allclose(f(-0.5), -(-0.5 + 1)) + assert np.allclose(f(-2.5), (-2.5 + 1)) class TestCatalystGrad: From fd03214b4e66af8bc5af5bd2488c92b7fde17f62 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sun, 21 Jul 2024 00:20:21 -0400 Subject: [PATCH 10/20] test --- pennylane/ops/op_math/condition.py | 7 +++++-- tests/ops/op_math/test_condition.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index d6de3de58bf..a64078a2b4b 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -498,13 +498,16 @@ def qnode(a, x, y, z): if true_fn is None: raise TypeError( "cond missing 1 required positional argument: 'true_fn'.\n" - "Note that if the conditional includes a mid-circuit measurement," + "Note that if the conditional includes a mid-circuit measurement, " "qml.cond cannot be used as a decorator.\n" "Instead, please use the form qml.cond(condition, true_fn, false_fn)." ) if elifs: - raise ConditionalTransformError("'elif' branches are not supported in interpreted mode.") + raise ConditionalTransformError( + "'elif' branches are not supported when not using @qjit and the " + "conditional include mid-circuit measurements." + ) if callable(true_fn): # We assume that the callable is an operation or a quantum function diff --git a/tests/ops/op_math/test_condition.py b/tests/ops/op_math/test_condition.py index b8f852d4569..2db00e49fc5 100644 --- a/tests/ops/op_math/test_condition.py +++ b/tests/ops/op_math/test_condition.py @@ -607,6 +607,21 @@ def conditional_false_fn(y): # pylint: disable=unused-variable assert np.allclose(f(-0.5), -(-0.5 + 1)) assert np.allclose(f(-2.5), (-2.5 + 1)) + def test_error_mcms_elif(self): + """Test that an error is raised if elifs are provided + when the conditional includes an MCM""" + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + m = qml.measure(0) + qml.cond(m, qml.RX, elifs=[(~m, qml.RY)]) + return qml.probs + + with pytest.raises(ConditionalTransformError, match="'elif' branches are not supported"): + circuit(0.5) + def test_error_no_true_fn(self): """Test that an error is raised if no true_fn is provided when the conditional includes an MCM""" From 9f9acb757c6f4f2380ee987ee604881ce2efa367 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sun, 21 Jul 2024 00:28:57 -0400 Subject: [PATCH 11/20] lint --- pennylane/ops/op_math/condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index a64078a2b4b..ea548c795a7 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -474,7 +474,7 @@ def qnode(a, x, y, z): ops_loader = available_eps[active_jit]["ops"].load() if true_fn is None: - return lambda fn: ops_loader.cond(condition)(fn) + return ops_loader.cond(condition) cond_func = ops_loader.cond(condition)(true_fn) From c2aee7dce27b5447743d02739e7108746e3761d7 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sun, 21 Jul 2024 00:32:47 -0400 Subject: [PATCH 12/20] isort --- pennylane/ops/op_math/condition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index ea548c795a7..ee7ed86947d 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -19,8 +19,8 @@ from pennylane import QueuingManager from pennylane.compiler import compiler -from pennylane.operation import AnyWires, Operation, Operator from pennylane.measurements import MeasurementValue +from pennylane.operation import AnyWires, Operation, Operator from pennylane.ops.op_math.symbolicop import SymbolicOp from pennylane.tape import make_qscript @@ -202,7 +202,7 @@ def __call__(self, *args, **kwargs): if pred: return branch_fn(*args, **kwargs) - return self.false_fn(*args, **kwargs) + return self.false_fn(*args, **kwargs) # pylint: disable=not-callable def cond(condition, true_fn=None, false_fn=None, elifs=()): From ba83c8d668d2c032ff1f33cc72a66e1bd5ad26d3 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sun, 21 Jul 2024 00:48:25 -0400 Subject: [PATCH 13/20] remove device test --- tests/test_device.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/test_device.py b/tests/test_device.py index 1639f76ed57..8ab79159072 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -575,21 +575,6 @@ def test_mcm_unsupported_error(self, mock_device_with_paulis_and_methods): with pytest.raises(DeviceError, match="Mid-circuit measurements are not natively"): dev.check_validity(tape.operations, tape.observables) - def test_conditional_ops_unsupported_error(self, mock_device_with_paulis_and_methods): - """Test that an error is raised for conditional operations if - mid-circuit measurements are not supported natively""" - dev = mock_device_with_paulis_and_methods(wires=2) - - with qml.queuing.AnnotatedQueue() as q: - qml.cond(0, qml.RY)(0.3, wires=0) - qml.PauliZ(0) - - tape = qml.tape.QuantumScript.from_queue(q) - # Raises an error for device that doesn't support conditional - # operations natively - with pytest.raises(DeviceError, match="Gate Conditional\\(RY\\) not supported on device"): - dev.check_validity(tape.operations, tape.observables) - @pytest.mark.parametrize( "wires, subset, expected_subset", [ From f31d153169678f78e86f20d7898336670f57254e Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Mon, 22 Jul 2024 20:17:12 -0400 Subject: [PATCH 14/20] Apply suggestions from code review Co-authored-by: Christina Lee --- pennylane/compiler/qjit_api.py | 30 +++++++++++++++++++++++++++++- tests/test_compiler.py | 2 +- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 9cefcb11445..d9a614bdb13 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -379,7 +379,18 @@ def loop_rx(x): # if there is no active compiler, simply interpret the while loop # via the Python interpretor. - def _decorator(body_fn): + def _decorator(body_fn: Callable) -> Callable: + """Transform that will call the input ``body_fn`` until the closure variable ``cond_fn`` is met. + + Args: + body_fn (Callable): + + Closure Variables: + cond_fn (Callable): + + Returns: + Callable: a callable with the same signature as ``body_fn`` and ``cond_fn``. + """ return WhileLoopCallable(cond_fn, body_fn) return _decorator @@ -497,6 +508,23 @@ def loop_rx(i, x): # if there is no active compiler, simply interpret the for loop # via the Python interpretor. def _decorator(body_fn): + """Transform that will call the input ``body_fn`` within the for loop. + + Args: + body_fn (Callable): The function called within the for loop. Note that the loop body + function must always have the iteration index as its first + argument, which can be used arbitrarily inside the loop body. As the value of the index + across iterations is handled automatically by the provided loop bounds, it must not be + returned from the function. + + Closure Variables: + lower_bound (int): starting value of the iteration index + upper_bound (int): (exclusive) upper bound of the iteration index + step (int): increment applied to the iteration index at the end of each iteration + + Returns: + Callable: a callable with the same signature as ``body_fn`` +""" return ForLoopCallable(lower_bound, upper_bound, step, body_fn) return _decorator diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 4bcf5b4a150..a0e72a1b0df 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -472,7 +472,7 @@ def inner(j): qml.RX(0.7, wires=[2]), ] - assert [qml.equal(i, j) for i, j in zip(res, expected)] + _ = [qml.assert_equal(i, j) for i, j in zip(res, expected)] def test_cond(self): """Test condition with simple true_fn""" From ee2532242b9d6da4f6e1c37125519e9fecf94a5b Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Mon, 22 Jul 2024 20:18:01 -0400 Subject: [PATCH 15/20] Update pennylane/compiler/qjit_api.py --- pennylane/compiler/qjit_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index d9a614bdb13..2df97613cb1 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -508,7 +508,7 @@ def loop_rx(i, x): # if there is no active compiler, simply interpret the for loop # via the Python interpretor. def _decorator(body_fn): - """Transform that will call the input ``body_fn`` within the for loop. + """Transform that will call the input ``body_fn`` within a for loop defined by the closure variables lower_bound, upper_bound, and step. Args: body_fn (Callable): The function called within the for loop. Note that the loop body From 01428a590d0d931c970c65b5c16719721e3bddb9 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Mon, 22 Jul 2024 20:22:06 -0400 Subject: [PATCH 16/20] Update pennylane/ops/op_math/condition.py --- pennylane/ops/op_math/condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index ee7ed86947d..575a414f94d 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -110,7 +110,7 @@ class CondCallable: # pylint:disable=too-few-public-methods false_fn (callable): The function to apply if ``condition`` is ``False`` elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses. - Passing ``true_fn``, ``false_fn``, and ``elifs`` on initialization + Passing ``false_fn`` and ``elifs`` on initialization is optional; these functions can be registered post-initialization via decorators: From ab489abd3a6f7a82a5b96ecb97d7b798abc87b29 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Tue, 23 Jul 2024 11:13:25 -0400 Subject: [PATCH 17/20] suggested changes --- pennylane/compiler/qjit_api.py | 75 +++++++++++++++++++--------------- tests/test_compiler.py | 23 +++++++++-- 2 files changed, 61 insertions(+), 37 deletions(-) diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 2df97613cb1..8d26b37422d 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """QJIT compatible quantum and compilation operations API""" +from collections.abc import Callable from .compiler import ( AvailableCompilers, @@ -380,17 +381,17 @@ def loop_rx(x): # if there is no active compiler, simply interpret the while loop # via the Python interpretor. def _decorator(body_fn: Callable) -> Callable: - """Transform that will call the input ``body_fn`` until the closure variable ``cond_fn`` is met. - - Args: - body_fn (Callable): - - Closure Variables: - cond_fn (Callable): - - Returns: - Callable: a callable with the same signature as ``body_fn`` and ``cond_fn``. - """ + """Transform that will call the input ``body_fn`` until the closure variable ``cond_fn`` is met. + + Args: + body_fn (Callable): + + Closure Variables: + cond_fn (Callable): + + Returns: + Callable: a callable with the same signature as ``body_fn`` and ``cond_fn``. + """ return WhileLoopCallable(cond_fn, body_fn) return _decorator @@ -422,16 +423,9 @@ def __call__(self, *init_state): def for_loop(lower_bound, upper_bound, step): - """A :func:`~.qjit` compatible for-loop for PennyLane programs. - - .. note:: - - This function only supports the Catalyst compiler. See - :func:`catalyst.for_loop` for more details. - - Please see the Catalyst :doc:`quickstart guide `, - as well as the :doc:`sharp bits and debugging tips ` - page for an overview of the differences between Catalyst and PennyLane. + """A :func:`~.qjit` compatible for-loop for PennyLane programs. When + used without :func:`~.qjit`, this function will fall back to a standard + Python for loop. This decorator provides a functional version of the traditional for-loop, similar to `jax.cond.fori_loop `__. @@ -479,7 +473,6 @@ def for_loop(lower_bound, upper_bound, step, loop_fn, *args): dev = qml.device("lightning.qubit", wires=1) - @qml.qjit @qml.qnode(dev) def circuit(n: int, x: float): @@ -494,10 +487,24 @@ def loop_rx(i, x): # apply the for loop final_x = loop_rx(x) - return qml.expval(qml.Z(0)), final_x + return qml.expval(qml.Z(0)) >>> circuit(7, 1.6) - (array(0.97926626), array(0.55395718)) + array(0.97926626) + + ``for_loop`` is also :func:`~.qjit` compatible; when used with the + :func:`~.qjit` decorator, the for loop will not be unrolled, and instead + will be captured as-is during compilation and executed during runtime: + + >>> qml.qjit(circuit)(7, 1.6) + Array(0.97926626, dtype=float64) + + .. note:: + + Please see the Catalyst :doc:`quickstart guide `, + as well as the :doc:`sharp bits and debugging tips ` + page for an overview of using quantum just-in-time compilation. + """ if active_jit := active_compiler(): @@ -508,23 +515,23 @@ def loop_rx(i, x): # if there is no active compiler, simply interpret the for loop # via the Python interpretor. def _decorator(body_fn): - """Transform that will call the input ``body_fn`` within a for loop defined by the closure variables lower_bound, upper_bound, and step. - - Args: - body_fn (Callable): The function called within the for loop. Note that the loop body + """Transform that will call the input ``body_fn`` within a for loop defined by the closure variables lower_bound, upper_bound, and step. + + Args: + body_fn (Callable): The function called within the for loop. Note that the loop body function must always have the iteration index as its first argument, which can be used arbitrarily inside the loop body. As the value of the index across iterations is handled automatically by the provided loop bounds, it must not be returned from the function. - - Closure Variables: + + Closure Variables: lower_bound (int): starting value of the iteration index upper_bound (int): (exclusive) upper bound of the iteration index step (int): increment applied to the iteration index at the end of each iteration - - Returns: - Callable: a callable with the same signature as ``body_fn`` -""" + + Returns: + Callable: a callable with the same signature as ``body_fn`` + """ return ForLoopCallable(lower_bound, upper_bound, step, body_fn) return _decorator diff --git a/tests/test_compiler.py b/tests/test_compiler.py index a0e72a1b0df..5b052ed4d1d 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -371,6 +371,23 @@ def inner(j): assert f(5, 6) == 30 # 5 * 6 assert f(4, 7) == 28 # 4 * 7 + def test_fallback_while_loop_qnode(self): + """Test that qml.while_loop inside a qnode fallsback to + Python without qjit""" + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev) + def circuit(n): + @qml.while_loop(lambda v: v[0] < v[1]) + def loop(v): + qml.PauliX(wires=0) + return v[0] + 1, v[1] + + loop((0, n)) + return qml.expval(qml.PauliZ(0)) + + assert jnp.allclose(circuit(1), -1.0) + def test_dynamic_wires_for_loops(self): """Test for loops with iteration index-dependant wires.""" dev = qml.device("lightning.qubit", wires=6) @@ -426,7 +443,7 @@ def inner(j): def test_for_loop_python_fallback(self): """Test that qml.for_loop fallsback to Python interpretation if Catalyst is not available""" - dev = qml.device("lightning.qubit", wires=2) + dev = qml.device("lightning.qubit", wires=3) @qml.qnode(dev) def circuit(x, n): @@ -449,7 +466,7 @@ def inner(j): inner() - return jnp.sin(x) + return x + 0.1 loop_fn() loop_fn_returns(x) @@ -457,7 +474,7 @@ def inner(j): return qml.expval(qml.PauliZ(0)) x = 0.5 - assert jnp.allclose(circuit(x, 2), qml.qjit(circuit)(x, 2)) + assert jnp.allclose(circuit(x, 3), qml.qjit(circuit)(x, 3)) res = circuit.tape.operations expected = [ From f833dd33c64aeb8a2589b53b8ea0e3c7f8e2d509 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Tue, 23 Jul 2024 11:15:38 -0400 Subject: [PATCH 18/20] suggested changes - test --- tests/test_compiler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 5b052ed4d1d..92130d5f3fd 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -388,6 +388,10 @@ def loop(v): assert jnp.allclose(circuit(1), -1.0) + res = circuit.tape.operations + expected = [qml.PauliX(0) for i in range(4)] + _ = [qml.assert_equal(i, j) for i, j in zip(res, expected)] + def test_dynamic_wires_for_loops(self): """Test for loops with iteration index-dependant wires.""" dev = qml.device("lightning.qubit", wires=6) From 99a1454042e8b425f2d8628d51a2904018b6a501 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Thu, 1 Aug 2024 16:54:25 -0400 Subject: [PATCH 19/20] Incorporate the 'captured' and 'non-captured' versions of `qml.cond` into `CondCallable` (#6063) **Context:** For 'historical reasons', the captured version (see sc-66774) and the non-captured version (see sc-69157) of `qml.cond` have a different style, which is rather uneven. The purpose of this story is to incorporate them in such a way that the implementation is more similar to the one in Catalyst. Or, to make another example, the implementation of `qml.for_loop` in PL (see sc-66736 and sc-69432). The purpose is not to change/add functionalities, but rather to unify the code structure so that it is more elegant. **Description of the Change:** As above. **Benefits:** Better and cleaner structure, more similar to Catalyst. **Possible Drawbacks:** None that I can think of right now directly due to this PR (we just moved some code). **Related GitHub Issues:** None **Related Shortcut Stories:** [sc-70342] --- pennylane/ops/op_math/condition.py | 120 ++++++++++++++++------------- tests/capture/test_capture_cond.py | 24 +++--- 2 files changed, 77 insertions(+), 67 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 830c2ca6424..c9335152e1d 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -142,10 +142,14 @@ def __init__(self, condition, true_fn, false_fn=None, elifs=()): self.branch_fns = [true_fn] self.otherwise_fn = false_fn - if false_fn is None: + # when working with `qml.capture.enabled()`, + # it's easier to store the original `elifs` argument + self.orig_elifs = elifs + + if false_fn is None and not qml.capture.enabled(): self.otherwise_fn = lambda *args, **kwargs: None - if elifs: + if elifs and not qml.capture.enabled(): elif_preds, elif_fns = list(zip(*elifs)) self.preds.extend(elif_preds) self.branch_fns.extend(elif_fns) @@ -198,7 +202,7 @@ def elifs(self): """(List(Tuple(bool, callable))): a list of (bool, elif_fn) clauses""" return list(zip(self.preds[1:], self.branch_fns[1:])) - def __call__(self, *args, **kwargs): + def __call_capture_disabled(self, *args, **kwargs): # python fallback for pred, branch_fn in zip(self.preds, self.branch_fns): if pred: @@ -206,6 +210,65 @@ def __call__(self, *args, **kwargs): return self.false_fn(*args, **kwargs) # pylint: disable=not-callable + def __call_capture_enabled(self, *args, **kwargs): + + import jax # pylint: disable=import-outside-toplevel + + cond_prim = _get_cond_qfunc_prim() + + elifs = ( + (self.orig_elifs,) + if len(self.orig_elifs) > 0 and not isinstance(self.orig_elifs[0], tuple) + else self.orig_elifs + ) + + @wraps(self.true_fn) + def new_wrapper(*args, **kwargs): + + jaxpr_true = jax.make_jaxpr(functools.partial(self.true_fn, **kwargs))(*args) + jaxpr_false = ( + jax.make_jaxpr(functools.partial(self.otherwise_fn, **kwargs))(*args) + if self.otherwise_fn + else None + ) + + # We extract each condition (or predicate) from the elifs argument list + # since these are traced by JAX and are passed as positional arguments to the primitive + elifs_conditions = [] + jaxpr_elifs = [] + + for pred, elif_fn in elifs: + elifs_conditions.append(pred) + jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) + + conditions = jax.numpy.array([self.condition, *elifs_conditions, True]) + + jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false] + jaxpr_consts = [jaxpr.consts if jaxpr is not None else () for jaxpr in jaxpr_branches] + + # We need to flatten the constants since JAX does not allow + # to pass lists as positional arguments + consts_flat = [const for sublist in jaxpr_consts for const in sublist] + n_consts_per_branch = [len(consts) for consts in jaxpr_consts] + + return cond_prim.bind( + conditions, + *args, + *consts_flat, + jaxpr_branches=jaxpr_branches, + n_consts_per_branch=n_consts_per_branch, + n_args=len(args), + ) + + return new_wrapper(*args, **kwargs) + + def __call__(self, *args, **kwargs): + + if qml.capture.enabled(): + return self.__call_capture_enabled(*args, **kwargs) + + return self.__call_capture_disabled(*args, **kwargs) + def cond(condition, true_fn: Callable = None, false_fn: Optional[Callable] = None, elifs=()): """Quantum-compatible if-else conditionals --- condition quantum operations @@ -501,9 +564,6 @@ def qnode(a, x, y, z): return cond_func - if qml.capture.enabled(): - return _capture_cond(condition, true_fn, false_fn, elifs) - if not isinstance(condition, MeasurementValue): # The condition is not a mid-circuit measurement. if true_fn is None: @@ -646,51 +706,3 @@ def _(*_, jaxpr_branches, **__): return outvals_true return cond_prim - - -def _capture_cond(condition, true_fn, false_fn=None, elifs=()) -> Callable: - """Capture compatible way to apply conditionals.""" - - import jax # pylint: disable=import-outside-toplevel - - cond_prim = _get_cond_qfunc_prim() - - elifs = (elifs,) if len(elifs) > 0 and not isinstance(elifs[0], tuple) else elifs - - @wraps(true_fn) - def new_wrapper(*args, **kwargs): - - jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) - jaxpr_false = ( - jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None - ) - - # We extract each condition (or predicate) from the elifs argument list - # since these are traced by JAX and are passed as positional arguments to the primitive - elifs_conditions = [] - jaxpr_elifs = [] - - for pred, elif_fn in elifs: - elifs_conditions.append(pred) - jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) - - conditions = jax.numpy.array([condition, *elifs_conditions, True]) - - jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false] - jaxpr_consts = [jaxpr.consts if jaxpr is not None else () for jaxpr in jaxpr_branches] - - # We need to flatten the constants since JAX does not allow - # to pass lists as positional arguments - consts_flat = [const for sublist in jaxpr_consts for const in sublist] - n_consts_per_branch = [len(consts) for consts in jaxpr_consts] - - return cond_prim.bind( - conditions, - *args, - *consts_flat, - jaxpr_branches=jaxpr_branches, - n_consts_per_branch=n_consts_per_branch, - n_args=len(args), - ) - - return new_wrapper diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index a6ddfa63dce..a3a70486157 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -22,7 +22,7 @@ import pytest import pennylane as qml -from pennylane.ops.op_math.condition import _capture_cond +from pennylane.ops.op_math.condition import CondCallable pytestmark = pytest.mark.jax @@ -229,7 +229,7 @@ class TestCondReturns: def test_validate_mismatches(self, true_fn, false_fn, expected_error, match): """Test mismatch in number and type of output variables.""" with pytest.raises(expected_error, match=match): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + jax.make_jaxpr(CondCallable(True, true_fn, false_fn))(jax.numpy.array(1)) def test_validate_number_of_output_variables(self): """Test mismatch in number of output variables.""" @@ -241,7 +241,7 @@ def false_fn(x): return x + 1 with pytest.raises(ValueError, match=r"Mismatch in number of output variables"): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + jax.make_jaxpr(CondCallable(True, true_fn, false_fn))(jax.numpy.array(1)) def test_validate_output_variable_types(self): """Test mismatch in output variable types.""" @@ -253,7 +253,7 @@ def false_fn(x): return x + 1, x + 2.0 with pytest.raises(ValueError, match=r"Mismatch in output abstract values"): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + jax.make_jaxpr(CondCallable(True, true_fn, false_fn))(jax.numpy.array(1)) def test_validate_no_false_branch_with_return(self): """Test no false branch provided with return variables.""" @@ -265,7 +265,7 @@ def true_fn(x): ValueError, match=r"The false branch must be provided if the true branch returns any variables", ): - jax.make_jaxpr(_capture_cond(True, true_fn))(jax.numpy.array(1)) + jax.make_jaxpr(CondCallable(True, true_fn))(jax.numpy.array(1)) def test_validate_no_false_branch_with_return_2(self): """Test no false branch provided with return variables.""" @@ -280,9 +280,7 @@ def elif_fn(x): ValueError, match=r"The false branch must be provided if the true branch returns any variables", ): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn=None, elifs=(False, elif_fn)))( - jax.numpy.array(1) - ) + jax.make_jaxpr(CondCallable(True, true_fn, elifs=[(True, elif_fn)]))(jax.numpy.array(1)) def test_validate_elif_branches(self): """Test elif branch mismatches.""" @@ -306,13 +304,13 @@ def elif_fn3(x): ValueError, match=r"Mismatch in output abstract values in elif branch #1" ): jax.make_jaxpr( - _capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) + CondCallable(True, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) )(jax.numpy.array(1)) with pytest.raises( ValueError, match=r"Mismatch in number of output variables in elif branch #0" ): - jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))( + jax.make_jaxpr(CondCallable(True, true_fn, false_fn, elifs=[(True, elif_fn3)]))( jax.numpy.array(1) ) @@ -328,7 +326,7 @@ def true_fn(): qml.RX(0.1, wires=0) qml.cond(pred > 0, true_fn)() - return qml.expval(qml.PauliZ(wires=0)) + return qml.expval(qml.Z(wires=0)) @qml.qnode(dev) @@ -352,7 +350,7 @@ def elif_fn1(arg1, arg2): qml.cond(pred > 0, true_fn, false_fn, elifs=(pred == -1, elif_fn1))(arg1, arg2) qml.RX(0.10, wires=0) - return qml.expval(qml.PauliZ(wires=0)) + return qml.expval(qml.Z(wires=0)) @qml.qnode(dev) @@ -371,7 +369,7 @@ def false_fn(arg1, arg2): qml.cond(pred > 0, true_fn, false_fn)(arg1, arg2) qml.RX(0.10, wires=0) - return qml.expval(qml.PauliZ(wires=0)) + return qml.expval(qml.Z(wires=0)) @qml.qnode(dev) From a659340364e73569a67aed9db71cc7adca87fb8f Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Thu, 1 Aug 2024 21:57:16 -0400 Subject: [PATCH 20/20] suggested change --- pennylane/ops/op_math/condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index c9335152e1d..60380d3ad01 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -307,7 +307,7 @@ def cond(condition, true_fn: Callable = None, false_fn: Optional[Callable] = Non If a branch returns one or more variables, every other branch must return the same abstract values. Args: - condition (Union[.MeasurementValue, bool]): a conditional expression involving a mid-circuit + condition (Union[.MeasurementValue, bool]): a conditional expression that may involve a mid-circuit measurement value (see :func:`.pennylane.measure`). true_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``True``