Skip to content

Commit

Permalink
raise warning if control indicators are hidden in drawer (#4295)
Browse files Browse the repository at this point in the history
* raise warning if control indicators are hidden in drawer

* add actionable suggestion to warning

* fix import order

* only check if multiple target wires

* just check all control wires

* Update doc/releases/changelog-dev.md
  • Loading branch information
timmysilv authored Jun 23, 2023
1 parent 6e0d11a commit 20b676d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@

<h3>Bug fixes 🐛</h3>

* Raise a warning if control indicators are hidden when calling `qml.draw_mpl`
[(#4295)](https://github.com/PennyLaneAI/pennylane/pull/4295)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Matthew Silverman
10 changes: 10 additions & 0 deletions pennylane/drawer/mpldrawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
This module contains the MPLDrawer class for creating circuit diagrams with matplotlib
"""
from collections.abc import Iterable
import warnings

has_mpl = True
try:
Expand Down Expand Up @@ -591,6 +592,15 @@ def ctrl(self, layer, wires, wires_target=None, control_values=None, options=Non
min_wire = min(wires_all)
max_wire = max(wires_all)

if len(wires_target) > 1:
min_target, max_target = min(wires_target), max(wires_target)
if any(min_target < w < max_target for w in wires_ctrl):
warnings.warn(
"Some control indicators are hidden behind an operator. Consider re-ordering "
"your circuit wires to ensure all control indicators are visible.",
UserWarning,
)

line = plt.Line2D((layer, layer), (min_wire, max_wire), **options)
self._ax.add_line(line)

Expand Down
23 changes: 23 additions & 0 deletions tests/drawer/test_mpldrawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
# pylint: disable=protected-access,wrong-import-position

import warnings
import pytest

plt = pytest.importorskip("matplotlib.pyplot")
Expand Down Expand Up @@ -429,6 +430,28 @@ def test_ctrl_target(self):
assert circle.center == (0, 0)
plt.close()

@pytest.mark.parametrize(
"control_wires,target_wires",
[
((1,), (0, 2)),
((0, 2), (1, 3)),
((1, 3), (0, 2)),
((0, 2, 4), (1, 3)),
],
)
def test_ctrl_raises_warning_with_overlap(self, control_wires, target_wires):
"""Tests that a warning is raised if some control indicators are not visible."""
drawer = MPLDrawer(1, 4)
with pytest.warns(UserWarning, match="control indicators are hidden behind an operator"):
drawer.ctrl(0, control_wires, target_wires)

@pytest.mark.parametrize("control_wires,target_wires", [((0,), (1, 2)), ((2,), (0, 1))])
def test_ctrl_no_warning_without_overlap(self, control_wires, target_wires):
drawer = MPLDrawer(1, 3)
with warnings.catch_warnings(record=True) as w:
drawer.ctrl(0, control_wires, target_wires)
assert len(w) == 0

def test_target_x(self):
"""Tests hidden target_x drawing method"""

Expand Down

0 comments on commit 20b676d

Please sign in to comment.