Skip to content

Commit

Permalink
Treat aux_wire and device_wires the same between metric_tensor
Browse files Browse the repository at this point in the history
…and `hadamard_grad` (#4328)

* update and tests

* changelog

* remove print

* second bug

* changelog 2

* legacy tests

* black

* Update tests/transforms/test_metric_tensor.py

* fix

* review

* trigger
  • Loading branch information
dwierichs authored Jul 12, 2023
1 parent b781be0 commit 18d508a
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 55 deletions.
11 changes: 10 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

<h3>Improvements 🛠</h3>

* Treat auxiliary wires and device wires in the same way in `transforms.metric_tensor`
as in `gradients.hadamard_grad`. Support all valid wire input formats for `aux_wire`.
[(#4328)](https://github.com/PennyLaneAI/pennylane/pull/4328)

* `qml.equal` no longer raises errors when operators or measurements of different types are compared.
Instead, it returns `False`.
[(#4315)](https://github.com/PennyLaneAI/pennylane/pull/4315)
Expand Down Expand Up @@ -76,6 +80,10 @@
<h3>Documentation 📝</h3>

<h3>Bug fixes 🐛</h3>

* Stop `metric_tensor` from accidentally catching errors that stem from
flawed wires assignments in the original circuit, leading to recursion errors
[(#4328)](https://github.com/PennyLaneAI/pennylane/pull/4328)

* Raise a warning if control indicators are hidden when calling `qml.draw_mpl`
[(#4295)](https://github.com/PennyLaneAI/pennylane/pull/4295)
Expand All @@ -91,4 +99,5 @@ Edward Jiang,
Christina Lee,
Mudit Pandey,
Borja Requena,
Matthew Silverman
Matthew Silverman,
David Wierichs,
14 changes: 2 additions & 12 deletions pennylane/gradients/hadamard_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,18 +196,8 @@ def _hadamard_grad(

argnum = [i for i, dm in method_map.items() if dm == "A"]

if device_wires and len(tape.wires) == len(device_wires):
raise qml.QuantumFunctionError("The device has no free wire for the auxiliary wire.")

# Get default for aux_wire
if aux_wire is None:
aux_wire = _get_aux_wire(aux_wire, tape, device_wires)
elif aux_wire[0] in tape.wires:
raise qml.QuantumFunctionError("The auxiliary wire is already used.")
elif aux_wire[0] not in device_wires:
raise qml.QuantumFunctionError(
"The requested auxiliary wire does not exist on the used device."
)
# Validate or get default for aux_wire
aux_wire = _get_aux_wire(aux_wire, tape, device_wires)

g_tapes, processing_fn = _expval_hadamard_grad(tape, argnum, aux_wire)

Expand Down
51 changes: 35 additions & 16 deletions pennylane/transforms/metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ def metric_tensor(
allow_nonunitary (bool): Whether non-unitary operations are allowed in circuits
created by the transform. Only relevant if ``approx`` is ``None``.
Should be set to ``True`` if possible to reduce cost.
aux_wire (int or str or pennylane.wires.Wires): Auxiliary wire to be used for
Hadamard tests. If ``None`` (the default), a suitable wire is inferred
from the (number of) used wires in the original circuit and ``device_wires``.
aux_wire (None or int or str or Sequence or pennylane.wires.Wires): Auxiliary wire to
be used for Hadamard tests. If ``None`` (the default), a suitable wire is inferred
from the (number of) used wires in the original circuit and ``device_wires``,
if the latter are given.
device_wires (.wires.Wires): Wires of the device that is going to be used for the
metric tensor. Facilitates finding a default for ``aux_wire`` if ``aux_wire``
is ``None``.
Expand Down Expand Up @@ -424,22 +425,29 @@ def wrapper(*args, **kwargs): # pylint: disable=too-many-branches
try:
mt = mt_fn(*args, **kwargs)
except qml.wires.WireError as e:
if str(e) == "No device wires are unused by the tape.":
revert_text = (
"\n\nReverting to the block-diagonal approximation. It will often be "
"much more efficient to request the block-diagonal approximation directly!"
)
other_mt_errors = [
"The requested auxiliary wire is already in use by the circuit.",
"The requested auxiliary wire does not exist on the used device.",
]

if str(e) == "The device has no free wire for the auxiliary wire.":
warnings.warn(
"The device does not have a wire that is not used by the tape."
"\n\nReverting to the block-diagonal approximation. It will often be "
"much more efficient to request the block-diagonal approximation directly!"
"The device does not have a wire that is not used by the circuit." + revert_text
)
else:
elif str(e) in other_mt_errors:
warnings.warn(
"An auxiliary wire is not available."
"\n\nThis can occur when computing the full metric tensor via the "
"Hadamard test, and the device does not provide an "
"additional wire or the requested auxiliary wire does not exist "
"on the device."
"\n\nReverting to the block-diagonal approximation. It will often be "
"much more efficient to request the block-diagonal approximation directly!"
"on the device." + revert_text
)
else:
raise e
tkwargs["approx"] = "block-diag"
return self(qnode, *targs, **tkwargs)(*args, **kwargs)

Expand Down Expand Up @@ -853,25 +861,36 @@ def _get_aux_wire(aux_wire, tape, device_wires):
r"""Determine an unused wire to be used as auxiliary wire for Hadamard tests.
Args:
aux_wire (object): Input auxiliary wire. Returned unmodified if not ``None``
aux_wire (object): Input auxiliary wire. May be one of a variety of input formats:
If ``None``, try to infer a reasonable choice based on the number of wires used
in the ``tape``, and based on ``device_wires``, if they are not ``None``.
If an ``int``, a ``str`` or a ``Sequence``, convert the input to a ``Wires``
object and take the first entry of the result. This leads to consistent behaviour
between ``_get_aux_wire`` and the ``Wires`` class.
If a ``Wires`` instance already, the conversion to such an instance is performed
trivially as well (also see the source code of ``~.Wires``).
tape (pennylane.tape.QuantumTape): Tape to infer the wire for
device_wires (.wires.Wires): Wires of the device that is going to be used for the
metric tensor. Facilitates finding a default for ``aux_wire`` if ``aux_wire``
is ``None`` .
Returns:
object: The auxiliary wire to be used. Equals ``aux_wire`` if it was not ``None`` ,
object: The auxiliary wire to be used. Equals ``aux_wire`` if it was not ``None``\ ,
and an often reasonable choice else.
"""
if aux_wire is not None:
aux_wire = qml.wires.Wires(aux_wire)[0]
if aux_wire in tape.wires:
msg = "The requested auxiliary wire is already in use by the circuit."
raise qml.wires.WireError(msg)
if device_wires is None or aux_wire in device_wires:
return aux_wire
raise qml.wires.WireError("The requested aux_wire does not exist on the used device.")
raise qml.wires.WireError("The requested auxiliary wire does not exist on the used device.")

if device_wires is not None:
if len(device_wires) == len(tape.wires):
raise qml.wires.WireError("The device has no free wire for the auxiliary wire.")
unused_wires = qml.wires.Wires(device_wires.toset().difference(tape.wires.toset()))
if not unused_wires:
raise qml.wires.WireError("No device wires are unused by the tape.")
return unused_wires[0]

_wires = tape.wires
Expand Down
47 changes: 28 additions & 19 deletions tests/gradients/core/test_hadamard_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ class TestHadamardGradEdgeCases:
device_wires = [qml.wires.Wires([0, 1, "aux"])]
device_wires_no_aux = [qml.wires.Wires([0, 1, 2])]

working_wires = [None, qml.wires.Wires("aux")]
working_wires = [None, qml.wires.Wires("aux"), "aux"]
already_used_wires = [qml.wires.Wires(0), qml.wires.Wires(1)]

@pytest.mark.parametrize("aux_wire", working_wires)
Expand All @@ -566,6 +566,8 @@ def test_aux_wire(self, aux_wire, device_wires):

tapes, _ = qml.gradients.hadamard_grad(tape, aux_wire=aux_wire, device_wires=dev.wires)
assert len(tapes) == 2
tapes, _ = qml.gradients.hadamard_grad(tape, aux_wire=aux_wire)
assert len(tapes) == 2

@pytest.mark.parametrize("aux_wire", already_used_wires)
@pytest.mark.parametrize("device_wires", device_wires)
Expand All @@ -583,7 +585,8 @@ def test_aux_wire_already_used_wires(self, aux_wire, device_wires):

tape = qml.tape.QuantumScript.from_queue(q)

with pytest.raises(qml.QuantumFunctionError, match="The auxiliary wire is already."):
_match = "The requested auxiliary wire is already in use by the circuit"
with pytest.raises(qml.wires.WireError, match=_match):
qml.gradients.hadamard_grad(tape, aux_wire=aux_wire, device_wires=dev.wires)

@pytest.mark.parametrize("device_wires", device_wires_no_aux)
Expand All @@ -601,30 +604,36 @@ def test_requested_wire_not_exist(self, device_wires):
qml.expval(qml.PauliZ(0) @ qml.PauliX(1))

tape = qml.tape.QuantumScript.from_queue(q)
with pytest.raises(
qml.QuantumFunctionError,
match="The requested auxiliary wire does not exist on the used device.",
):
_match = "The requested auxiliary wire does not exist on the used device"
with pytest.raises(qml.wires.WireError, match=_match):
qml.gradients.hadamard_grad(tape, aux_wire=aux_wire, device_wires=dev.wires)

@pytest.mark.parametrize("aux_wire", working_wires + already_used_wires)
@pytest.mark.parametrize("aux_wire", [None] + already_used_wires)
def test_device_not_enough_wires(self, aux_wire):
"""Test that an error is raised when the device cannot accept an auxiliary wire because it is full."""
"""Test that an error is raised when the device cannot accept an auxiliary wire
because it is full."""
dev = qml.device("default.qubit", wires=2)
x = 0.543
y = -0.654

with qml.queuing.AnnotatedQueue() as q:
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0) @ qml.PauliX(1))
m = qml.expval(qml.PauliZ(0) @ qml.PauliX(1))
tape = qml.tape.QuantumScript([qml.RX(0.543, wires=[0]), qml.RY(-0.654, wires=[1])], [m])

tape = qml.tape.QuantumScript.from_queue(q)
if aux_wire is None:
_match = "The device has no free wire for the auxiliary wire."
else:
_match = "The requested auxiliary wire is already in use by the circuit."
with pytest.raises(qml.wires.WireError, match=_match):
qml.gradients.hadamard_grad(tape, aux_wire=aux_wire, device_wires=dev.wires)

with pytest.raises(
qml.QuantumFunctionError, match="The device has no free wire for the auxiliary wire."
):
def test_device_wire_does_not_exist(self):
"""Test that an error is raised when the device cannot accept an auxiliary wire
because it does not exist on the device."""
aux_wire = qml.wires.Wires("aux")
dev = qml.device("default.qubit", wires=2)
m = qml.expval(qml.PauliZ(0) @ qml.PauliX(1))
tape = qml.tape.QuantumScript([qml.RX(0.543, wires=[0]), qml.RY(-0.654, wires=[1])], [m])

_match = "The requested auxiliary wire does not exist on the used device."
with pytest.raises(qml.wires.WireError, match=_match):
qml.gradients.hadamard_grad(tape, aux_wire=aux_wire, device_wires=dev.wires)

def test_empty_circuit(self):
Expand Down
12 changes: 8 additions & 4 deletions tests/legacy/test_legacy_metric_tensor_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def circuit(weights):

error_msg = (
"Some parameters specified in argnum are not in the "
"trainable parameters \[0, 1, 2, 3\] of the tape "
r"trainable parameters \[0, 1, 2, 3\] of the tape "
"and will be ignored. This may be caused by attempting to "
"differentiate with respect to parameters that are not marked "
"as trainable."
Expand Down Expand Up @@ -1758,9 +1758,13 @@ def test_get_aux_wire_with_device_wires():
tape = qml.tape.QuantumScript.from_queue(q)
device_wires = qml.wires.Wires([0, "aux", "one"])

assert _get_aux_wire(0, tape, device_wires) == 0
assert _get_aux_wire("one", tape, device_wires) == "one"
assert _get_aux_wire(None, tape, device_wires) == "aux"
assert _get_aux_wire("aux", tape, device_wires) == "aux"
_match = "The requested auxiliary wire is already in use by the circuit."
with pytest.raises(qml.wires.WireError, match=_match):
_get_aux_wire("one", tape, device_wires)
with pytest.raises(qml.wires.WireError, match=_match):
_get_aux_wire(0, tape, device_wires)


def test_get_aux_wire_with_unavailable_aux():
Expand All @@ -1771,5 +1775,5 @@ def test_get_aux_wire_with_unavailable_aux():
qml.RX(x, wires="one")
tape = qml.tape.QuantumScript.from_queue(q)
device_wires = qml.wires.Wires([0, "one"])
with pytest.raises(qml.wires.WireError, match="The requested aux_wire does not exist"):
with pytest.raises(qml.wires.WireError, match="The requested auxiliary wire does not exist"):
_get_aux_wire("two", tape, device_wires)
29 changes: 26 additions & 3 deletions tests/transforms/test_metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,25 @@ def circuit_multi_block(x, z):
assert len(recwarn) == 0


def test_raises_circuit_that_uses_missing_wire():
"""Test that an error in the original circuit is reraised properly and not caught. This avoids
accidentally catching relevant errors, which can lead to a recursion error."""

dev = qml.device("default.qubit", wires=[0, "b"])

@qml.qnode(dev)
def circuit(x):
"""Flawed circuit that uses a wire which is not on the device."""
qml.RX(x[0], 0)
qml.CNOT([0, 1]) # wire 1 is not on the device
qml.RX(x[1], 0)
return qml.expval(qml.PauliZ(0))

x = np.array([1.3, 0.2])
with pytest.raises(qml.wires.WireError, match=r"Did not find some of the wires \(0, 1\)"):
qml.transforms.metric_tensor(circuit)(x)


def aux_wire_ansatz_0(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=2)
Expand Down Expand Up @@ -1781,9 +1800,13 @@ def test_get_aux_wire_with_device_wires():
tape = qml.tape.QuantumScript.from_queue(q)
device_wires = qml.wires.Wires([0, "aux", "one"])

assert _get_aux_wire(0, tape, device_wires) == 0
assert _get_aux_wire("one", tape, device_wires) == "one"
assert _get_aux_wire(None, tape, device_wires) == "aux"
assert _get_aux_wire("aux", tape, device_wires) == "aux"
_match = "The requested auxiliary wire is already in use by the circuit."
with pytest.raises(qml.wires.WireError, match=_match):
_get_aux_wire("one", tape, device_wires)
with pytest.raises(qml.wires.WireError, match=_match):
_get_aux_wire(0, tape, device_wires)


def test_get_aux_wire_with_unavailable_aux():
Expand All @@ -1794,5 +1817,5 @@ def test_get_aux_wire_with_unavailable_aux():
qml.RX(y, wires="one")
tape = qml.tape.QuantumScript.from_queue(q)
device_wires = qml.wires.Wires([0, "one"])
with pytest.raises(qml.wires.WireError, match="The requested aux_wire does not exist"):
with pytest.raises(qml.wires.WireError, match="The requested auxiliary wire does not exist"):
_get_aux_wire("two", tape, device_wires)

0 comments on commit 18d508a

Please sign in to comment.