Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
Browse files Browse the repository at this point in the history
…o mini-dev-new
  • Loading branch information
astralcai committed Sep 3, 2024
2 parents 2f849ff + 20c024b commit 2264edf
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 12 deletions.
8 changes: 8 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

<h3>Improvements 🛠</h3>

* Improve unit testing for capturing of nested control flows.
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)

* Some custom primitives for the capture project can now be imported via
`from pennylane.capture.primitives import *`.
[(#6129)](https://github.com/PennyLaneAI/pennylane/pull/6129)
Expand All @@ -21,9 +24,14 @@
* Fix Pytree serialization of operators with empty shot vectors:
[(#6155)](https://github.com/PennyLaneAI/pennylane/pull/6155)

* Fix `qml.PrepSelPrep` template to work with `torch`:
[(#6191)](https://github.com/PennyLaneAI/pennylane/pull/6191)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Utkarsh Azad
Jack Brown
Christina Lee
William Maxwell
2 changes: 1 addition & 1 deletion pennylane/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.39.0-dev4"
__version__ = "0.39.0-dev6"
15 changes: 4 additions & 11 deletions pennylane/templates/subroutines/prepselprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,15 @@

def _get_new_terms(lcu):
"""Compute a new sum of unitaries with positive coefficients"""

new_coeffs = []
coeffs, ops = lcu.terms()
angles = qml.math.angle(coeffs)
new_ops = []

for coeff, op in zip(*lcu.terms()):

angle = qml.math.angle(coeff)
new_coeffs.append(qml.math.abs(coeff))

for angle, op in zip(angles, ops):
new_op = op @ qml.GlobalPhase(-angle, wires=op.wires)
new_ops.append(new_op)

interface = qml.math.get_interface(lcu.terms()[0])
new_coeffs = qml.math.array(new_coeffs, like=interface)

return new_coeffs, new_ops
return qml.math.abs(coeffs), new_ops


class PrepSelPrep(Operation):
Expand Down
67 changes: 67 additions & 0 deletions tests/capture/test_capture_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,73 @@ def f(*x):

assert np.allclose(res, expected, atol=atol, rtol=0), f"Expected {expected}, but got {res}"

@pytest.mark.parametrize("upper_bound, arg", [(3, [0.1, 0.3, 0.5]), (2, [2, 7, 12])])
def test_nested_cond_for_while_loop(self, upper_bound, arg):
"""Test that a nested control flows are correctly captured into a jaxpr."""

dev = qml.device("default.qubit", wires=3)

# Control flow for qml.conds
def true_fn(_):
@qml.for_loop(0, upper_bound, 1)
def loop_fn(i):
qml.Hadamard(wires=i)

loop_fn()

def elif_fn(arg):
qml.RY(arg**2, wires=[2])

def false_fn(arg):
qml.RY(-arg, wires=[2])

@qml.qnode(dev)
def circuit(upper_bound, arg):
qml.RY(-np.pi / 2, wires=[2])
m_0 = qml.measure(2)

# NOTE: qml.cond(m_0, qml.RX)(arg[1], wires=1) doesn't work
def rx_fn():
qml.RX(arg[1], wires=1)

qml.cond(m_0, rx_fn)()

def ry_fn():
qml.RY(arg[1] ** 3, wires=1)

# nested for loops.
# outer for loop updates x
@qml.for_loop(0, upper_bound, 1)
def loop_fn_returns(i, x):
qml.RX(x, wires=i)
m_1 = qml.measure(0)
# NOTE: qml.cond(m_0, qml.RY)(arg[1], wires=1) doesn't work
qml.cond(m_1, ry_fn)()

# inner while loop
@qml.while_loop(lambda j: j < upper_bound)
def inner(j):
qml.RZ(j, wires=0)
qml.RY(x**2, wires=0)
m_2 = qml.measure(0)
qml.cond(m_2, true_fn=true_fn, false_fn=false_fn, elifs=((m_1, elif_fn)))(
arg[0]
)
return j + 1

inner(i + 1)
return x + 0.1

loop_fn_returns(arg[2])

return qml.expval(qml.Z(0))

args = [upper_bound, arg]
result = circuit(*args)
jaxpr = jax.make_jaxpr(circuit)(*args)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, upper_bound, *arg)
assert np.allclose(result, res_ev_jxpr), f"Expected {result}, but got {res_ev_jxpr}"


class TestPytree:
"""Test pytree support for cond."""
Expand Down
45 changes: 45 additions & 0 deletions tests/capture/test_capture_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,51 @@ def inner(j):
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize(
"upper_bound, arg, expected", [(3, 0.5, 0.00223126), (2, 12, 0.2653001)]
)
def test_nested_for_and_while_loop(self, upper_bound, arg, expected):
"""Test that a nested for loop and while loop is correctly captured into a jaxpr."""

dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit(upper_bound, arg):

# for loop with dynamic bounds
@qml.for_loop(0, upper_bound, 1)
def loop_fn(i):
qml.Hadamard(wires=i)

# nested for-while loops.
@qml.for_loop(0, upper_bound, 1)
def loop_fn_returns(i, x):
qml.RX(x, wires=i)

# inner while loop
@qml.while_loop(lambda j: j < upper_bound)
def inner(j):
qml.RZ(j, wires=0)
qml.RY(x**2, wires=0)
return j + 1

inner(i + 1)

return x + 0.1

loop_fn()
loop_fn_returns(arg)

return qml.expval(qml.Z(0))

args = [upper_bound, arg]
result = circuit(*args)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

jaxpr = jax.make_jaxpr(circuit)(*args)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"


def test_pytree_inputs():
"""Test that for_loop works with pytree inputs and outputs."""
Expand Down
37 changes: 37 additions & 0 deletions tests/capture/test_capture_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,43 @@ def inner(j):
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize("upper_bound, arg", [(3, 0.5), (2, 12)])
def test_while_and_for_loop_nested(self, upper_bound, arg):
"""Test that a nested while and for loop is correctly captured into a jaxpr."""

dev = qml.device("default.qubit", wires=3)

def ry_fn(arg):
qml.RY(arg, wires=1)

@qml.qnode(dev)
def circuit(upper_bound, arg):

# while loop with dynamic bounds
@qml.while_loop(lambda i: i < upper_bound)
def loop_fn(i):
qml.Hadamard(wires=i)

@qml.for_loop(0, i, 1)
def loop_fn_returns(i, x):
qml.RX(x, wires=i)
m_0 = qml.measure(0)
qml.cond(m_0, ry_fn)(x)
return i + 1

loop_fn_returns(arg)
return i + 1

loop_fn(0)

return qml.expval(qml.Z(0))

args = [upper_bound, arg]
result = circuit(*args)
jaxpr = jax.make_jaxpr(circuit)(*args)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(result, res_ev_jxpr), f"Expected {result}, but got {res_ev_jxpr}"


def test_pytree_input_output():
"""Test that the while loop supports pytree input and output."""
Expand Down
20 changes: 20 additions & 0 deletions tests/templates/test_subroutines/test_prepselprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,26 @@ class TestInterfaces:
params = np.array([0.4, 0.5, 0.1, 0.3])
exp_grad = [0.41177732, -0.21262349, 1.6437038, -0.74256516]

@pytest.mark.torch
def test_torch(self):
"""Test the torch interface"""
import torch

dev = qml.device("default.qubit")

@qml.qnode(dev)
def circuit(coeffs):
H = qml.ops.LinearCombination(
coeffs, [qml.Y(0), qml.Y(1) @ qml.Y(2), qml.X(0), qml.X(1) @ qml.X(2)]
)
qml.PrepSelPrep(H, control=(3, 4))
return qml.expval(qml.PauliZ(3) @ qml.PauliZ(4))

params = torch.tensor(self.params)
res = torch.autograd.functional.jacobian(circuit, params)
assert qml.math.shape(res) == (4,)
assert np.allclose(res, self.exp_grad, atol=1e-5)

@pytest.mark.autograd
def test_autograd(self):
"""Test the autograd interface"""
Expand Down

0 comments on commit 2264edf

Please sign in to comment.