diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 3b9c85ace17..9c030886b8d 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -121,6 +121,9 @@
Hamiltonians.
[(#5950)](https://github.com/PennyLaneAI/pennylane/pull/5950)
+* The `CNOT` operator no longer decomposes to itself. Instead, it raises a `qml.DecompositionUndefinedError`.
+ [(#6039)](https://github.com/PennyLaneAI/pennylane/pull/6039)
+
Community contributions 🥳
* `DefaultQutritMixed` readout error has been added using parameters `readout_relaxation_probs` and
diff --git a/pennylane/ops/functions/assert_valid.py b/pennylane/ops/functions/assert_valid.py
index bc92514a4ce..e3795ff6608 100644
--- a/pennylane/ops/functions/assert_valid.py
+++ b/pennylane/ops/functions/assert_valid.py
@@ -56,6 +56,9 @@ def _check_decomposition(op, skip_wire_mapping):
assert isinstance(decomp, list), "decomposition must be a list"
assert isinstance(compute_decomp, list), "decomposition must be a list"
+ assert op.__class__ not in [
+ decomp_op.__class__ for decomp_op in decomp
+ ], "an operator should not be included in its own decomposition"
for o1, o2, o3 in zip(decomp, compute_decomp, processed_queue):
assert o1 == o2, "decomposition must match compute_decomposition"
@@ -67,12 +70,17 @@ def _check_decomposition(op, skip_wire_mapping):
# Check that mapping wires transitions to the decomposition
wire_map = {w: ascii_lowercase[i] for i, w in enumerate(op.wires)}
mapped_op = op.map_wires(wire_map)
- mapped_decomp = mapped_op.decomposition()
- orig_decomp = op.decomposition()
- for mapped_op, orig_op in zip(mapped_decomp, orig_decomp):
- assert (
- mapped_op.wires == qml.map_wires(orig_op, wire_map).wires
- ), "Operators in decomposition of wire-mapped operator must have mapped wires."
+ # calling `map_wires` on a Controlled operator generates a new `op` from the controls and
+ # base, so may return a different class of operator. We only compare decomps of `op` and
+ # `mapped_op` if `mapped_op` **has** a decomposition.
+ # see MultiControlledX([0, 1]) and CNOT([0, 1]) as an example
+ if mapped_op.has_decomposition:
+ mapped_decomp = mapped_op.decomposition()
+ orig_decomp = op.decomposition()
+ for mapped_op, orig_op in zip(mapped_decomp, orig_decomp):
+ assert (
+ mapped_op.wires == qml.map_wires(orig_op, wire_map).wires
+ ), "Operators in decomposition of wire-mapped operator must have mapped wires."
else:
failure_comment = "If has_decomposition is False, then decomposition must raise a ``DecompositionUndefinedError``."
_assert_error_raised(
diff --git a/pennylane/ops/op_math/controlled_ops.py b/pennylane/ops/op_math/controlled_ops.py
index d2c5d3afb4a..4d2a8475f6b 100644
--- a/pennylane/ops/op_math/controlled_ops.py
+++ b/pennylane/ops/op_math/controlled_ops.py
@@ -749,10 +749,10 @@ class CNOT(ControlledOp):
The controlled-NOT operator
.. math:: CNOT = \begin{bmatrix}
- 1 & 0 & 0 & 0 \\
- 0 & 1 & 0 & 0\\
- 0 & 0 & 0 & 1\\
- 0 & 0 & 1 & 0
+ 1 & 0 & 0 & 0 \\
+ 0 & 1 & 0 & 0\\
+ 0 & 0 & 0 & 1\\
+ 0 & 0 & 1 & 0
\end{bmatrix}.
.. note:: The first wire provided corresponds to the **control qubit**.
@@ -791,6 +791,32 @@ def _primitive_bind_call(cls, wires, id=None):
def __init__(self, wires, id=None):
super().__init__(qml.PauliX(wires=wires[1:]), wires[:1], id=id)
+ @property
+ def has_decomposition(self):
+ return False
+
+ @staticmethod
+ def compute_decomposition(*params, wires=None, **hyperparameters): # -> List["Operator"]:
+ r"""Representation of the operator as a product of other operators (static method).
+
+ .. math:: O = O_1 O_2 \dots O_n.
+
+ .. note::
+ Operations making up the decomposition should be queued within the
+ ``compute_decomposition`` method.
+
+ .. seealso:: :meth:`~.Operator.decomposition`.
+
+ Args:
+ *params (list): trainable parameters of the operator, as stored in the ``parameters`` attribute
+ wires (Iterable[Any], Wires): wires that the operator acts on
+ **hyperparams (dict): non-trainable hyperparameters of the operator, as stored in the ``hyperparameters`` attribute
+
+ Raises:
+ qml.DecompositionUndefinedError
+ """
+ raise qml.operation.DecompositionUndefinedError
+
def __repr__(self):
return f"CNOT(wires={self.wires.tolist()})"
diff --git a/pennylane/transforms/tape_expand.py b/pennylane/transforms/tape_expand.py
index f3e9b0d26e5..60580eb3f67 100644
--- a/pennylane/transforms/tape_expand.py
+++ b/pennylane/transforms/tape_expand.py
@@ -297,14 +297,17 @@ def _custom_decomposition(obj, fn):
obj = getattr(qml, obj)
original_decomp_method = obj.compute_decomposition
+ original_has_decomp_property = obj.has_decomposition
try:
# Explicitly set the new compute_decomposition method
obj.compute_decomposition = staticmethod(fn)
+ obj.has_decomposition = lambda obj: True
yield
finally:
obj.compute_decomposition = staticmethod(original_decomp_method)
+ obj.has_decomposition = original_has_decomp_property
# Loop through the decomposition dictionary and create all the contexts
try:
diff --git a/tests/ops/functions/test_assert_valid.py b/tests/ops/functions/test_assert_valid.py
index b9a6cf45368..9fff7758338 100644
--- a/tests/ops/functions/test_assert_valid.py
+++ b/tests/ops/functions/test_assert_valid.py
@@ -80,6 +80,17 @@ def compute_decomposition(wires):
with pytest.raises(AssertionError, match="If has_decomposition is False"):
assert_valid(BadDecomp(wires=0), skip_pickle=True)
+ def test_decomposition_must_not_contain_op(self):
+ """Test that the decomposition of an operator doesn't include the operator itself"""
+
+ class BadDecomp(Operator):
+ @staticmethod
+ def compute_decomposition(wires):
+ return [BadDecomp(wires)]
+
+ with pytest.raises(AssertionError, match="should not be included in its own decomposition"):
+ assert_valid(BadDecomp(wires=0), skip_pickle=True)
+
class TestBadMatrix:
"""Tests involving matrix validation."""
diff --git a/tests/ops/op_math/test_controlled_ops.py b/tests/ops/op_math/test_controlled_ops.py
index d82308daf65..7a002cccda4 100644
--- a/tests/ops/op_math/test_controlled_ops.py
+++ b/tests/ops/op_math/test_controlled_ops.py
@@ -763,3 +763,15 @@ def test_tuple_control_wires_parametric_ops(op_type):
"""Test that tuples can be provided as control wire labels."""
assert op_type(0.123, [(0, 1), 2]).wires == qml.wires.Wires([(0, 1), 2])
+
+
+def test_CNOT_decomposition():
+ """Test that CNOT raises a DecompositionUndefinedError instead of using the
+ controlled_op decomposition functions"""
+ assert not qml.CNOT((0, 1)).has_decomposition
+
+ with pytest.raises(qml.operation.DecompositionUndefinedError):
+ qml.CNOT.compute_decomposition()
+
+ with pytest.raises(qml.operation.DecompositionUndefinedError):
+ qml.CNOT([0, 1]).decomposition()