diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 40e559447e1..51e2fb5f36f 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -6,6 +6,9 @@
Improvements 🛠
+* Improve unit testing for capturing of nested control flows.
+ [(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)
+
* Some custom primitives for the capture project can now be imported via
`from pennylane.capture.primitives import *`.
[(#6129)](https://github.com/PennyLaneAI/pennylane/pull/6129)
@@ -21,9 +24,14 @@
* Fix Pytree serialization of operators with empty shot vectors:
[(#6155)](https://github.com/PennyLaneAI/pennylane/pull/6155)
+* Fix `qml.PrepSelPrep` template to work with `torch`:
+ [(#6191)](https://github.com/PennyLaneAI/pennylane/pull/6191)
+
Contributors ✍️
This release contains contributions from (in alphabetical order):
+Utkarsh Azad
Jack Brown
Christina Lee
+William Maxwell
diff --git a/pennylane/_version.py b/pennylane/_version.py
index e94ae8e0a64..3c61a75803b 100644
--- a/pennylane/_version.py
+++ b/pennylane/_version.py
@@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""
-__version__ = "0.39.0-dev4"
+__version__ = "0.39.0-dev6"
diff --git a/pennylane/templates/subroutines/prepselprep.py b/pennylane/templates/subroutines/prepselprep.py
index 53df7f96cfc..c9796427615 100644
--- a/pennylane/templates/subroutines/prepselprep.py
+++ b/pennylane/templates/subroutines/prepselprep.py
@@ -23,22 +23,15 @@
def _get_new_terms(lcu):
"""Compute a new sum of unitaries with positive coefficients"""
-
- new_coeffs = []
+ coeffs, ops = lcu.terms()
+ angles = qml.math.angle(coeffs)
new_ops = []
- for coeff, op in zip(*lcu.terms()):
-
- angle = qml.math.angle(coeff)
- new_coeffs.append(qml.math.abs(coeff))
-
+ for angle, op in zip(angles, ops):
new_op = op @ qml.GlobalPhase(-angle, wires=op.wires)
new_ops.append(new_op)
- interface = qml.math.get_interface(lcu.terms()[0])
- new_coeffs = qml.math.array(new_coeffs, like=interface)
-
- return new_coeffs, new_ops
+ return qml.math.abs(coeffs), new_ops
class PrepSelPrep(Operation):
diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py
index bed3d848e55..7a92107e264 100644
--- a/tests/capture/test_capture_cond.py
+++ b/tests/capture/test_capture_cond.py
@@ -703,6 +703,73 @@ def f(*x):
assert np.allclose(res, expected, atol=atol, rtol=0), f"Expected {expected}, but got {res}"
+ @pytest.mark.parametrize("upper_bound, arg", [(3, [0.1, 0.3, 0.5]), (2, [2, 7, 12])])
+ def test_nested_cond_for_while_loop(self, upper_bound, arg):
+ """Test that a nested control flows are correctly captured into a jaxpr."""
+
+ dev = qml.device("default.qubit", wires=3)
+
+ # Control flow for qml.conds
+ def true_fn(_):
+ @qml.for_loop(0, upper_bound, 1)
+ def loop_fn(i):
+ qml.Hadamard(wires=i)
+
+ loop_fn()
+
+ def elif_fn(arg):
+ qml.RY(arg**2, wires=[2])
+
+ def false_fn(arg):
+ qml.RY(-arg, wires=[2])
+
+ @qml.qnode(dev)
+ def circuit(upper_bound, arg):
+ qml.RY(-np.pi / 2, wires=[2])
+ m_0 = qml.measure(2)
+
+ # NOTE: qml.cond(m_0, qml.RX)(arg[1], wires=1) doesn't work
+ def rx_fn():
+ qml.RX(arg[1], wires=1)
+
+ qml.cond(m_0, rx_fn)()
+
+ def ry_fn():
+ qml.RY(arg[1] ** 3, wires=1)
+
+ # nested for loops.
+ # outer for loop updates x
+ @qml.for_loop(0, upper_bound, 1)
+ def loop_fn_returns(i, x):
+ qml.RX(x, wires=i)
+ m_1 = qml.measure(0)
+ # NOTE: qml.cond(m_0, qml.RY)(arg[1], wires=1) doesn't work
+ qml.cond(m_1, ry_fn)()
+
+ # inner while loop
+ @qml.while_loop(lambda j: j < upper_bound)
+ def inner(j):
+ qml.RZ(j, wires=0)
+ qml.RY(x**2, wires=0)
+ m_2 = qml.measure(0)
+ qml.cond(m_2, true_fn=true_fn, false_fn=false_fn, elifs=((m_1, elif_fn)))(
+ arg[0]
+ )
+ return j + 1
+
+ inner(i + 1)
+ return x + 0.1
+
+ loop_fn_returns(arg[2])
+
+ return qml.expval(qml.Z(0))
+
+ args = [upper_bound, arg]
+ result = circuit(*args)
+ jaxpr = jax.make_jaxpr(circuit)(*args)
+ res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, upper_bound, *arg)
+ assert np.allclose(result, res_ev_jxpr), f"Expected {result}, but got {res_ev_jxpr}"
+
class TestPytree:
"""Test pytree support for cond."""
diff --git a/tests/capture/test_capture_for_loop.py b/tests/capture/test_capture_for_loop.py
index 64671a295f3..d27c723e218 100644
--- a/tests/capture/test_capture_for_loop.py
+++ b/tests/capture/test_capture_for_loop.py
@@ -372,6 +372,51 @@ def inner(j):
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"
+ @pytest.mark.parametrize(
+ "upper_bound, arg, expected", [(3, 0.5, 0.00223126), (2, 12, 0.2653001)]
+ )
+ def test_nested_for_and_while_loop(self, upper_bound, arg, expected):
+ """Test that a nested for loop and while loop is correctly captured into a jaxpr."""
+
+ dev = qml.device("default.qubit", wires=3)
+
+ @qml.qnode(dev)
+ def circuit(upper_bound, arg):
+
+ # for loop with dynamic bounds
+ @qml.for_loop(0, upper_bound, 1)
+ def loop_fn(i):
+ qml.Hadamard(wires=i)
+
+ # nested for-while loops.
+ @qml.for_loop(0, upper_bound, 1)
+ def loop_fn_returns(i, x):
+ qml.RX(x, wires=i)
+
+ # inner while loop
+ @qml.while_loop(lambda j: j < upper_bound)
+ def inner(j):
+ qml.RZ(j, wires=0)
+ qml.RY(x**2, wires=0)
+ return j + 1
+
+ inner(i + 1)
+
+ return x + 0.1
+
+ loop_fn()
+ loop_fn_returns(arg)
+
+ return qml.expval(qml.Z(0))
+
+ args = [upper_bound, arg]
+ result = circuit(*args)
+ assert np.allclose(result, expected), f"Expected {expected}, but got {result}"
+
+ jaxpr = jax.make_jaxpr(circuit)(*args)
+ res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
+ assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"
+
def test_pytree_inputs():
"""Test that for_loop works with pytree inputs and outputs."""
diff --git a/tests/capture/test_capture_while_loop.py b/tests/capture/test_capture_while_loop.py
index 33e9466ab78..d87f6299ba7 100644
--- a/tests/capture/test_capture_while_loop.py
+++ b/tests/capture/test_capture_while_loop.py
@@ -219,6 +219,43 @@ def inner(j):
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"
+ @pytest.mark.parametrize("upper_bound, arg", [(3, 0.5), (2, 12)])
+ def test_while_and_for_loop_nested(self, upper_bound, arg):
+ """Test that a nested while and for loop is correctly captured into a jaxpr."""
+
+ dev = qml.device("default.qubit", wires=3)
+
+ def ry_fn(arg):
+ qml.RY(arg, wires=1)
+
+ @qml.qnode(dev)
+ def circuit(upper_bound, arg):
+
+ # while loop with dynamic bounds
+ @qml.while_loop(lambda i: i < upper_bound)
+ def loop_fn(i):
+ qml.Hadamard(wires=i)
+
+ @qml.for_loop(0, i, 1)
+ def loop_fn_returns(i, x):
+ qml.RX(x, wires=i)
+ m_0 = qml.measure(0)
+ qml.cond(m_0, ry_fn)(x)
+ return i + 1
+
+ loop_fn_returns(arg)
+ return i + 1
+
+ loop_fn(0)
+
+ return qml.expval(qml.Z(0))
+
+ args = [upper_bound, arg]
+ result = circuit(*args)
+ jaxpr = jax.make_jaxpr(circuit)(*args)
+ res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
+ assert np.allclose(result, res_ev_jxpr), f"Expected {result}, but got {res_ev_jxpr}"
+
def test_pytree_input_output():
"""Test that the while loop supports pytree input and output."""
diff --git a/tests/templates/test_subroutines/test_prepselprep.py b/tests/templates/test_subroutines/test_prepselprep.py
index 82629973865..95e7f771ef7 100644
--- a/tests/templates/test_subroutines/test_prepselprep.py
+++ b/tests/templates/test_subroutines/test_prepselprep.py
@@ -315,6 +315,26 @@ class TestInterfaces:
params = np.array([0.4, 0.5, 0.1, 0.3])
exp_grad = [0.41177732, -0.21262349, 1.6437038, -0.74256516]
+ @pytest.mark.torch
+ def test_torch(self):
+ """Test the torch interface"""
+ import torch
+
+ dev = qml.device("default.qubit")
+
+ @qml.qnode(dev)
+ def circuit(coeffs):
+ H = qml.ops.LinearCombination(
+ coeffs, [qml.Y(0), qml.Y(1) @ qml.Y(2), qml.X(0), qml.X(1) @ qml.X(2)]
+ )
+ qml.PrepSelPrep(H, control=(3, 4))
+ return qml.expval(qml.PauliZ(3) @ qml.PauliZ(4))
+
+ params = torch.tensor(self.params)
+ res = torch.autograd.functional.jacobian(circuit, params)
+ assert qml.math.shape(res) == (4,)
+ assert np.allclose(res, self.exp_grad, atol=1e-5)
+
@pytest.mark.autograd
def test_autograd(self):
"""Test the autograd interface"""