diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index b117f4b3f82..d0edbac6bfd 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -339,6 +339,9 @@
Bug fixes 🐛
+* Operators applied to all wires are now drawn correctly in a circuit with mid-circuit measurements.
+ [(#5501)](https://github.com/PennyLaneAI/pennylane/pull/5501)
+
* Fix a bug where certain unary mid-circuit measurement expressions would raise an uncaught error.
[(#5480)](https://github.com/PennyLaneAI/pennylane/pull/5480)
diff --git a/pennylane/drawer/tape_text.py b/pennylane/drawer/tape_text.py
index 8e635ef50ea..7069e4749fe 100644
--- a/pennylane/drawer/tape_text.py
+++ b/pennylane/drawer/tape_text.py
@@ -142,7 +142,8 @@ def _add_op(op, layer_str, config):
label = op.label(decimals=config.decimals, cache=config.cache).replace("\n", "")
if len(op.wires) == 0: # operation (e.g. barrier, snapshot) across all wires
- for i, s in enumerate(layer_str):
+ n_wires = len(config.wire_map)
+ for i, s in enumerate(layer_str[:n_wires]):
layer_str[i] = s + label
else:
for w in op.wires:
@@ -225,7 +226,8 @@ def _add_measurement(m, layer_str, config):
meas_label = m.return_type.value
if len(m.wires) == 0: # state or probability across all wires
- for i, s in enumerate(layer_str):
+ n_wires = len(config.wire_map)
+ for i, s in enumerate(layer_str[:n_wires]):
layer_str[i] = s + meas_label
for w in m.wires:
diff --git a/tests/drawer/test_draw.py b/tests/drawer/test_draw.py
index 1bd7790f523..b571ed4fd0d 100644
--- a/tests/drawer/test_draw.py
+++ b/tests/drawer/test_draw.py
@@ -284,6 +284,8 @@ def circ():
class TestMidCircuitMeasurements:
"""Tests for drawing mid-circuit measurements and classical conditions."""
+ # pylint: disable=too-many-public-methods
+
@pytest.mark.parametrize("device_name", ["default.qubit"])
def test_qnode_mid_circuit_measurement_not_deferred(self, device_name, mocker):
"""Test that a circuit containing mid-circuit measurements is transformed by the drawer
@@ -329,6 +331,51 @@ def func():
assert drawing == expected_drawing
+ @pytest.mark.parametrize(
+ "op", [qml.GlobalPhase(0.1), qml.Identity(), qml.Snapshot(), qml.Barrier()]
+ )
+ @pytest.mark.parametrize("decimals", [None, 2])
+ def test_draw_all_wire_ops(self, op, decimals):
+ """Test that operators acting on all wires are drawn correctly"""
+
+ def func():
+ qml.X(0)
+ qml.X(1)
+ m = qml.measure(0)
+ qml.cond(m, qml.X)(0)
+ qml.apply(op)
+ return qml.expval(qml.Z(0))
+
+ # Stripping to remove trailing white-space because length of white-space at the
+ # end of the drawing depends on the length of each individual line
+ drawing = qml.draw(func, decimals=decimals)().strip()
+ label = op.label(decimals=decimals).replace("\n", "")
+ expected_drawing = (
+ f"0: ──X──┤↗├──X──{label}─┤ \n1: ──X───║───║──{label}─┤ \n ╚═══╝"
+ )
+
+ assert drawing == expected_drawing
+
+ @pytest.mark.parametrize(
+ "mp, label", [(qml.sample(), "Sample"), (qml.probs(), "Probs"), (qml.counts(), "Counts")]
+ )
+ def test_draw_all_wire_measurements(self, mp, label):
+ """Test that operators acting on all wires are drawn correctly"""
+
+ def func():
+ qml.X(0)
+ qml.X(1)
+ m = qml.measure(0)
+ qml.cond(m, qml.X)(0)
+ return qml.apply(mp)
+
+ # Stripping to remove trailing white-space because length of white-space at the
+ # end of the drawing depends on the length of each individual line
+ drawing = qml.draw(func)().strip()
+ expected_drawing = f"0: ──X──┤↗├──X─┤ {label}\n1: ──X───║───║─┤ {label}\n ╚═══╝"
+
+ assert drawing == expected_drawing
+
def test_draw_mid_circuit_measurement_multiple_wires(self):
"""Test that mid-circuit measurements are correctly drawn in circuits
with multiple wires."""