From b67eea1578b2faf0f34f38697f2dff88f474f6ce Mon Sep 17 00:00:00 2001 From: lillian542 Date: Wed, 24 Jul 2024 18:26:01 -0400 Subject: [PATCH] update assert_valid to catch CNOT decomposition --- pennylane/ops/functions/assert_valid.py | 3 +++ tests/ops/functions/test_assert_valid.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/pennylane/ops/functions/assert_valid.py b/pennylane/ops/functions/assert_valid.py index bc92514a4ce..4befa9f35c5 100644 --- a/pennylane/ops/functions/assert_valid.py +++ b/pennylane/ops/functions/assert_valid.py @@ -56,6 +56,9 @@ def _check_decomposition(op, skip_wire_mapping): assert isinstance(decomp, list), "decomposition must be a list" assert isinstance(compute_decomp, list), "decomposition must be a list" + assert op.__class__ not in [ + decomp_op.__class__ for decomp_op in decomp + ], "an operator should not be included in its own decomposition" for o1, o2, o3 in zip(decomp, compute_decomp, processed_queue): assert o1 == o2, "decomposition must match compute_decomposition" diff --git a/tests/ops/functions/test_assert_valid.py b/tests/ops/functions/test_assert_valid.py index b9a6cf45368..bbbda986b2d 100644 --- a/tests/ops/functions/test_assert_valid.py +++ b/tests/ops/functions/test_assert_valid.py @@ -80,6 +80,17 @@ def compute_decomposition(wires): with pytest.raises(AssertionError, match="If has_decomposition is False"): assert_valid(BadDecomp(wires=0), skip_pickle=True) + def test_decomposition_must_not_contain_op(self): + """Test that decomposition op an operator doesn't include the operator itself""" + + class BadDecomp(Operator): + @staticmethod + def compute_decomposition(wires): + return [BadDecomp(wires)] + + with pytest.raises(AssertionError, match="should not be included in its own decomposition"): + assert_valid(BadDecomp(wires=0), skip_pickle=True) + class TestBadMatrix: """Tests involving matrix validation."""