Skip to content

Commit

Permalink
CNOT decomposition method returns DecompositionUndefinedError (#6039)
Browse files Browse the repository at this point in the history
**Context:**
Before CNOT inherited from `Controlled`, its decomposition method
returned a `DecompositionUndefinedError`. Now it uses the smart
decomposition functions from `Controlled`, which in this case make it
decompose to itself.

```
>>> op = qml.CNOT([0, 1])
>>> op.decomposition()
[CNOT(wires=[0, 1])]
```

The change in the decomposition method of CNOT was unintentional, and
operators should not decompose to themselves.

**Description of the Change:**
We put it back how it was:

```
>>> op = qml.CNOT([0, 1])
>>> op.decomposition()
DecompositionUndefinedError
```

**Related GitHub Issues:**
#5711

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
3 people authored Jul 26, 2024
1 parent eb7c349 commit c534c17
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 10 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@
Hamiltonians.
[(#5950)](https://github.com/PennyLaneAI/pennylane/pull/5950)

* The `CNOT` operator no longer decomposes to itself. Instead, it raises a `qml.DecompositionUndefinedError`.
[(#6039)](https://github.com/PennyLaneAI/pennylane/pull/6039)

<h4>Community contributions 🥳</h4>

* `DefaultQutritMixed` readout error has been added using parameters `readout_relaxation_probs` and
Expand Down
20 changes: 14 additions & 6 deletions pennylane/ops/functions/assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def _check_decomposition(op, skip_wire_mapping):

assert isinstance(decomp, list), "decomposition must be a list"
assert isinstance(compute_decomp, list), "decomposition must be a list"
assert op.__class__ not in [
decomp_op.__class__ for decomp_op in decomp
], "an operator should not be included in its own decomposition"

for o1, o2, o3 in zip(decomp, compute_decomp, processed_queue):
assert o1 == o2, "decomposition must match compute_decomposition"
Expand All @@ -67,12 +70,17 @@ def _check_decomposition(op, skip_wire_mapping):
# Check that mapping wires transitions to the decomposition
wire_map = {w: ascii_lowercase[i] for i, w in enumerate(op.wires)}
mapped_op = op.map_wires(wire_map)
mapped_decomp = mapped_op.decomposition()
orig_decomp = op.decomposition()
for mapped_op, orig_op in zip(mapped_decomp, orig_decomp):
assert (
mapped_op.wires == qml.map_wires(orig_op, wire_map).wires
), "Operators in decomposition of wire-mapped operator must have mapped wires."
# calling `map_wires` on a Controlled operator generates a new `op` from the controls and
# base, so may return a different class of operator. We only compare decomps of `op` and
# `mapped_op` if `mapped_op` **has** a decomposition.
# see MultiControlledX([0, 1]) and CNOT([0, 1]) as an example
if mapped_op.has_decomposition:
mapped_decomp = mapped_op.decomposition()
orig_decomp = op.decomposition()
for mapped_op, orig_op in zip(mapped_decomp, orig_decomp):
assert (
mapped_op.wires == qml.map_wires(orig_op, wire_map).wires
), "Operators in decomposition of wire-mapped operator must have mapped wires."
else:
failure_comment = "If has_decomposition is False, then decomposition must raise a ``DecompositionUndefinedError``."
_assert_error_raised(
Expand Down
34 changes: 30 additions & 4 deletions pennylane/ops/op_math/controlled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,10 +749,10 @@ class CNOT(ControlledOp):
The controlled-NOT operator
.. math:: CNOT = \begin{bmatrix}
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0\\
0 & 0 & 0 & 1\\
0 & 0 & 1 & 0
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0\\
0 & 0 & 0 & 1\\
0 & 0 & 1 & 0
\end{bmatrix}.
.. note:: The first wire provided corresponds to the **control qubit**.
Expand Down Expand Up @@ -791,6 +791,32 @@ def _primitive_bind_call(cls, wires, id=None):
def __init__(self, wires, id=None):
super().__init__(qml.PauliX(wires=wires[1:]), wires[:1], id=id)

@property
def has_decomposition(self):
return False

@staticmethod
def compute_decomposition(*params, wires=None, **hyperparameters): # -> List["Operator"]:
r"""Representation of the operator as a product of other operators (static method).
.. math:: O = O_1 O_2 \dots O_n.
.. note::
Operations making up the decomposition should be queued within the
``compute_decomposition`` method.
.. seealso:: :meth:`~.Operator.decomposition`.
Args:
*params (list): trainable parameters of the operator, as stored in the ``parameters`` attribute
wires (Iterable[Any], Wires): wires that the operator acts on
**hyperparams (dict): non-trainable hyperparameters of the operator, as stored in the ``hyperparameters`` attribute
Raises:
qml.DecompositionUndefinedError
"""
raise qml.operation.DecompositionUndefinedError

def __repr__(self):
return f"CNOT(wires={self.wires.tolist()})"

Expand Down
3 changes: 3 additions & 0 deletions pennylane/transforms/tape_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,17 @@ def _custom_decomposition(obj, fn):
obj = getattr(qml, obj)

original_decomp_method = obj.compute_decomposition
original_has_decomp_property = obj.has_decomposition

try:
# Explicitly set the new compute_decomposition method
obj.compute_decomposition = staticmethod(fn)
obj.has_decomposition = lambda obj: True
yield

finally:
obj.compute_decomposition = staticmethod(original_decomp_method)
obj.has_decomposition = original_has_decomp_property

# Loop through the decomposition dictionary and create all the contexts
try:
Expand Down
11 changes: 11 additions & 0 deletions tests/ops/functions/test_assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ def compute_decomposition(wires):
with pytest.raises(AssertionError, match="If has_decomposition is False"):
assert_valid(BadDecomp(wires=0), skip_pickle=True)

def test_decomposition_must_not_contain_op(self):
"""Test that the decomposition of an operator doesn't include the operator itself"""

class BadDecomp(Operator):
@staticmethod
def compute_decomposition(wires):
return [BadDecomp(wires)]

with pytest.raises(AssertionError, match="should not be included in its own decomposition"):
assert_valid(BadDecomp(wires=0), skip_pickle=True)


class TestBadMatrix:
"""Tests involving matrix validation."""
Expand Down
12 changes: 12 additions & 0 deletions tests/ops/op_math/test_controlled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,3 +763,15 @@ def test_tuple_control_wires_parametric_ops(op_type):
"""Test that tuples can be provided as control wire labels."""

assert op_type(0.123, [(0, 1), 2]).wires == qml.wires.Wires([(0, 1), 2])


def test_CNOT_decomposition():
"""Test that CNOT raises a DecompositionUndefinedError instead of using the
controlled_op decomposition functions"""
assert not qml.CNOT((0, 1)).has_decomposition

with pytest.raises(qml.operation.DecompositionUndefinedError):
qml.CNOT.compute_decomposition()

with pytest.raises(qml.operation.DecompositionUndefinedError):
qml.CNOT([0, 1]).decomposition()

0 comments on commit c534c17

Please sign in to comment.