diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index b117f4b3f82..d0edbac6bfd 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -339,6 +339,9 @@

Bug fixes 🐛

+* Operators applied to all wires are now drawn correctly in a circuit with mid-circuit measurements. + [(#5501)](https://github.com/PennyLaneAI/pennylane/pull/5501) + * Fix a bug where certain unary mid-circuit measurement expressions would raise an uncaught error. [(#5480)](https://github.com/PennyLaneAI/pennylane/pull/5480) diff --git a/pennylane/drawer/tape_text.py b/pennylane/drawer/tape_text.py index 8e635ef50ea..7069e4749fe 100644 --- a/pennylane/drawer/tape_text.py +++ b/pennylane/drawer/tape_text.py @@ -142,7 +142,8 @@ def _add_op(op, layer_str, config): label = op.label(decimals=config.decimals, cache=config.cache).replace("\n", "") if len(op.wires) == 0: # operation (e.g. barrier, snapshot) across all wires - for i, s in enumerate(layer_str): + n_wires = len(config.wire_map) + for i, s in enumerate(layer_str[:n_wires]): layer_str[i] = s + label else: for w in op.wires: @@ -225,7 +226,8 @@ def _add_measurement(m, layer_str, config): meas_label = m.return_type.value if len(m.wires) == 0: # state or probability across all wires - for i, s in enumerate(layer_str): + n_wires = len(config.wire_map) + for i, s in enumerate(layer_str[:n_wires]): layer_str[i] = s + meas_label for w in m.wires: diff --git a/tests/drawer/test_draw.py b/tests/drawer/test_draw.py index 1bd7790f523..b571ed4fd0d 100644 --- a/tests/drawer/test_draw.py +++ b/tests/drawer/test_draw.py @@ -284,6 +284,8 @@ def circ(): class TestMidCircuitMeasurements: """Tests for drawing mid-circuit measurements and classical conditions.""" + # pylint: disable=too-many-public-methods + @pytest.mark.parametrize("device_name", ["default.qubit"]) def test_qnode_mid_circuit_measurement_not_deferred(self, device_name, mocker): """Test that a circuit containing mid-circuit measurements is transformed by the drawer @@ -329,6 +331,51 @@ def func(): assert drawing == expected_drawing + @pytest.mark.parametrize( + "op", [qml.GlobalPhase(0.1), qml.Identity(), qml.Snapshot(), qml.Barrier()] + ) + @pytest.mark.parametrize("decimals", [None, 2]) + def test_draw_all_wire_ops(self, op, decimals): + """Test that operators acting on all wires are drawn correctly""" + + def func(): + qml.X(0) + qml.X(1) + m = qml.measure(0) + qml.cond(m, qml.X)(0) + qml.apply(op) + return qml.expval(qml.Z(0)) + + # Stripping to remove trailing white-space because length of white-space at the + # end of the drawing depends on the length of each individual line + drawing = qml.draw(func, decimals=decimals)().strip() + label = op.label(decimals=decimals).replace("\n", "") + expected_drawing = ( + f"0: ──X──┤↗├──X──{label}─┤ \n1: ──X───║───║──{label}─┤ \n ╚═══╝" + ) + + assert drawing == expected_drawing + + @pytest.mark.parametrize( + "mp, label", [(qml.sample(), "Sample"), (qml.probs(), "Probs"), (qml.counts(), "Counts")] + ) + def test_draw_all_wire_measurements(self, mp, label): + """Test that operators acting on all wires are drawn correctly""" + + def func(): + qml.X(0) + qml.X(1) + m = qml.measure(0) + qml.cond(m, qml.X)(0) + return qml.apply(mp) + + # Stripping to remove trailing white-space because length of white-space at the + # end of the drawing depends on the length of each individual line + drawing = qml.draw(func)().strip() + expected_drawing = f"0: ──X──┤↗├──X─┤ {label}\n1: ──X───║───║─┤ {label}\n ╚═══╝" + + assert drawing == expected_drawing + def test_draw_mid_circuit_measurement_multiple_wires(self): """Test that mid-circuit measurements are correctly drawn in circuits with multiple wires."""