From 46e10365f2308cce466683f436267de11d466f27 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Fri, 30 Aug 2024 20:21:24 -0400 Subject: [PATCH 1/4] Test support for capturing nested control flows (#6111) **Context:** Adds test for asserting correct support for capturing nested control flows **Description of the Change:** Adds new tests **Benefits:** **Possible Drawbacks:** N/A **Related GitHub Issues:** [sc-66776] --- doc/releases/changelog-dev.md | 4 ++ tests/capture/test_capture_cond.py | 67 ++++++++++++++++++++++++ tests/capture/test_capture_for_loop.py | 45 ++++++++++++++++ tests/capture/test_capture_while_loop.py | 37 +++++++++++++ 4 files changed, 153 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 40e559447e1..b859ce0f156 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,6 +6,9 @@

Improvements 🛠

+* Improve unit testing for capturing of nested control flows. + [(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111) + * Some custom primitives for the capture project can now be imported via `from pennylane.capture.primitives import *`. [(#6129)](https://github.com/PennyLaneAI/pennylane/pull/6129) @@ -25,5 +28,6 @@ This release contains contributions from (in alphabetical order): +Utkarsh Azad Jack Brown Christina Lee diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index bed3d848e55..7a92107e264 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -703,6 +703,73 @@ def f(*x): assert np.allclose(res, expected, atol=atol, rtol=0), f"Expected {expected}, but got {res}" + @pytest.mark.parametrize("upper_bound, arg", [(3, [0.1, 0.3, 0.5]), (2, [2, 7, 12])]) + def test_nested_cond_for_while_loop(self, upper_bound, arg): + """Test that a nested control flows are correctly captured into a jaxpr.""" + + dev = qml.device("default.qubit", wires=3) + + # Control flow for qml.conds + def true_fn(_): + @qml.for_loop(0, upper_bound, 1) + def loop_fn(i): + qml.Hadamard(wires=i) + + loop_fn() + + def elif_fn(arg): + qml.RY(arg**2, wires=[2]) + + def false_fn(arg): + qml.RY(-arg, wires=[2]) + + @qml.qnode(dev) + def circuit(upper_bound, arg): + qml.RY(-np.pi / 2, wires=[2]) + m_0 = qml.measure(2) + + # NOTE: qml.cond(m_0, qml.RX)(arg[1], wires=1) doesn't work + def rx_fn(): + qml.RX(arg[1], wires=1) + + qml.cond(m_0, rx_fn)() + + def ry_fn(): + qml.RY(arg[1] ** 3, wires=1) + + # nested for loops. + # outer for loop updates x + @qml.for_loop(0, upper_bound, 1) + def loop_fn_returns(i, x): + qml.RX(x, wires=i) + m_1 = qml.measure(0) + # NOTE: qml.cond(m_0, qml.RY)(arg[1], wires=1) doesn't work + qml.cond(m_1, ry_fn)() + + # inner while loop + @qml.while_loop(lambda j: j < upper_bound) + def inner(j): + qml.RZ(j, wires=0) + qml.RY(x**2, wires=0) + m_2 = qml.measure(0) + qml.cond(m_2, true_fn=true_fn, false_fn=false_fn, elifs=((m_1, elif_fn)))( + arg[0] + ) + return j + 1 + + inner(i + 1) + return x + 0.1 + + loop_fn_returns(arg[2]) + + return qml.expval(qml.Z(0)) + + args = [upper_bound, arg] + result = circuit(*args) + jaxpr = jax.make_jaxpr(circuit)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, upper_bound, *arg) + assert np.allclose(result, res_ev_jxpr), f"Expected {result}, but got {res_ev_jxpr}" + class TestPytree: """Test pytree support for cond.""" diff --git a/tests/capture/test_capture_for_loop.py b/tests/capture/test_capture_for_loop.py index 64671a295f3..d27c723e218 100644 --- a/tests/capture/test_capture_for_loop.py +++ b/tests/capture/test_capture_for_loop.py @@ -372,6 +372,51 @@ def inner(j): res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.parametrize( + "upper_bound, arg, expected", [(3, 0.5, 0.00223126), (2, 12, 0.2653001)] + ) + def test_nested_for_and_while_loop(self, upper_bound, arg, expected): + """Test that a nested for loop and while loop is correctly captured into a jaxpr.""" + + dev = qml.device("default.qubit", wires=3) + + @qml.qnode(dev) + def circuit(upper_bound, arg): + + # for loop with dynamic bounds + @qml.for_loop(0, upper_bound, 1) + def loop_fn(i): + qml.Hadamard(wires=i) + + # nested for-while loops. + @qml.for_loop(0, upper_bound, 1) + def loop_fn_returns(i, x): + qml.RX(x, wires=i) + + # inner while loop + @qml.while_loop(lambda j: j < upper_bound) + def inner(j): + qml.RZ(j, wires=0) + qml.RY(x**2, wires=0) + return j + 1 + + inner(i + 1) + + return x + 0.1 + + loop_fn() + loop_fn_returns(arg) + + return qml.expval(qml.Z(0)) + + args = [upper_bound, arg] + result = circuit(*args) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + jaxpr = jax.make_jaxpr(circuit)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + def test_pytree_inputs(): """Test that for_loop works with pytree inputs and outputs.""" diff --git a/tests/capture/test_capture_while_loop.py b/tests/capture/test_capture_while_loop.py index 33e9466ab78..d87f6299ba7 100644 --- a/tests/capture/test_capture_while_loop.py +++ b/tests/capture/test_capture_while_loop.py @@ -219,6 +219,43 @@ def inner(j): res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.parametrize("upper_bound, arg", [(3, 0.5), (2, 12)]) + def test_while_and_for_loop_nested(self, upper_bound, arg): + """Test that a nested while and for loop is correctly captured into a jaxpr.""" + + dev = qml.device("default.qubit", wires=3) + + def ry_fn(arg): + qml.RY(arg, wires=1) + + @qml.qnode(dev) + def circuit(upper_bound, arg): + + # while loop with dynamic bounds + @qml.while_loop(lambda i: i < upper_bound) + def loop_fn(i): + qml.Hadamard(wires=i) + + @qml.for_loop(0, i, 1) + def loop_fn_returns(i, x): + qml.RX(x, wires=i) + m_0 = qml.measure(0) + qml.cond(m_0, ry_fn)(x) + return i + 1 + + loop_fn_returns(arg) + return i + 1 + + loop_fn(0) + + return qml.expval(qml.Z(0)) + + args = [upper_bound, arg] + result = circuit(*args) + jaxpr = jax.make_jaxpr(circuit)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(result, res_ev_jxpr), f"Expected {result}, but got {res_ev_jxpr}" + def test_pytree_input_output(): """Test that the while loop supports pytree input and output.""" From 18e90721e622fe9a74b3068ff26815f8f8572cd1 Mon Sep 17 00:00:00 2001 From: ringo-but-quantum Date: Mon, 2 Sep 2024 09:51:53 +0000 Subject: [PATCH 2/4] [no ci] bump nightly version --- pennylane/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/_version.py b/pennylane/_version.py index e94ae8e0a64..614dbc746f5 100644 --- a/pennylane/_version.py +++ b/pennylane/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.39.0-dev4" +__version__ = "0.39.0-dev5" From 43dcbffe83bce05640c518543e12e2160a00b34b Mon Sep 17 00:00:00 2001 From: ringo-but-quantum Date: Tue, 3 Sep 2024 09:51:31 +0000 Subject: [PATCH 3/4] [no ci] bump nightly version --- pennylane/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/_version.py b/pennylane/_version.py index 614dbc746f5..3c61a75803b 100644 --- a/pennylane/_version.py +++ b/pennylane/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.39.0-dev5" +__version__ = "0.39.0-dev6" From 20c024b8aa0c653d89b2ea23457b71c27d837093 Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 3 Sep 2024 13:11:48 -0400 Subject: [PATCH 4/4] `PrepSelPrep` template works with `torch` (#6191) This PR fixes bug #6185 --------- Co-authored-by: ringo-but-quantum <> Co-authored-by: Guillermo Alonso-Linaje <65235481+KetpuntoG@users.noreply.github.com> --- doc/releases/changelog-dev.md | 4 ++++ .../templates/subroutines/prepselprep.py | 15 ++++---------- .../test_subroutines/test_prepselprep.py | 20 +++++++++++++++++++ 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index b859ce0f156..51e2fb5f36f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -24,6 +24,9 @@ * Fix Pytree serialization of operators with empty shot vectors: [(#6155)](https://github.com/PennyLaneAI/pennylane/pull/6155) +* Fix `qml.PrepSelPrep` template to work with `torch`: + [(#6191)](https://github.com/PennyLaneAI/pennylane/pull/6191) +

Contributors ✍️

This release contains contributions from (in alphabetical order): @@ -31,3 +34,4 @@ This release contains contributions from (in alphabetical order): Utkarsh Azad Jack Brown Christina Lee +William Maxwell diff --git a/pennylane/templates/subroutines/prepselprep.py b/pennylane/templates/subroutines/prepselprep.py index 53df7f96cfc..c9796427615 100644 --- a/pennylane/templates/subroutines/prepselprep.py +++ b/pennylane/templates/subroutines/prepselprep.py @@ -23,22 +23,15 @@ def _get_new_terms(lcu): """Compute a new sum of unitaries with positive coefficients""" - - new_coeffs = [] + coeffs, ops = lcu.terms() + angles = qml.math.angle(coeffs) new_ops = [] - for coeff, op in zip(*lcu.terms()): - - angle = qml.math.angle(coeff) - new_coeffs.append(qml.math.abs(coeff)) - + for angle, op in zip(angles, ops): new_op = op @ qml.GlobalPhase(-angle, wires=op.wires) new_ops.append(new_op) - interface = qml.math.get_interface(lcu.terms()[0]) - new_coeffs = qml.math.array(new_coeffs, like=interface) - - return new_coeffs, new_ops + return qml.math.abs(coeffs), new_ops class PrepSelPrep(Operation): diff --git a/tests/templates/test_subroutines/test_prepselprep.py b/tests/templates/test_subroutines/test_prepselprep.py index 82629973865..95e7f771ef7 100644 --- a/tests/templates/test_subroutines/test_prepselprep.py +++ b/tests/templates/test_subroutines/test_prepselprep.py @@ -315,6 +315,26 @@ class TestInterfaces: params = np.array([0.4, 0.5, 0.1, 0.3]) exp_grad = [0.41177732, -0.21262349, 1.6437038, -0.74256516] + @pytest.mark.torch + def test_torch(self): + """Test the torch interface""" + import torch + + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def circuit(coeffs): + H = qml.ops.LinearCombination( + coeffs, [qml.Y(0), qml.Y(1) @ qml.Y(2), qml.X(0), qml.X(1) @ qml.X(2)] + ) + qml.PrepSelPrep(H, control=(3, 4)) + return qml.expval(qml.PauliZ(3) @ qml.PauliZ(4)) + + params = torch.tensor(self.params) + res = torch.autograd.functional.jacobian(circuit, params) + assert qml.math.shape(res) == (4,) + assert np.allclose(res, self.exp_grad, atol=1e-5) + @pytest.mark.autograd def test_autograd(self): """Test the autograd interface"""