Skip to content

Commit

Permalink
Operations/measurements on all wires are drawn correctly with mid-cir…
Browse files Browse the repository at this point in the history
…cuit measurements (#5501)

**Context:**
Bug fix for #5499 

**Description of the Change:**
Adding labels to layer strings only extends to wires, not classical
bits.

**Benefits:**
Drawer works correctly with mid-circuit measurements

**Possible Drawbacks:**

**Related GitHub Issues:**
  • Loading branch information
mudit2812 committed Apr 12, 2024
1 parent 46afaa5 commit 3f52824
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@

<h3>Bug fixes 🐛</h3>

* Operators applied to all wires are now drawn correctly in a circuit with mid-circuit measurements.
[(#5501)](https://github.com/PennyLaneAI/pennylane/pull/5501)

* Fix a bug where certain unary mid-circuit measurement expressions would raise an uncaught error.
[(#5480)](https://github.com/PennyLaneAI/pennylane/pull/5480)

Expand Down
6 changes: 4 additions & 2 deletions pennylane/drawer/tape_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def _add_op(op, layer_str, config):

label = op.label(decimals=config.decimals, cache=config.cache).replace("\n", "")
if len(op.wires) == 0: # operation (e.g. barrier, snapshot) across all wires
for i, s in enumerate(layer_str):
n_wires = len(config.wire_map)
for i, s in enumerate(layer_str[:n_wires]):
layer_str[i] = s + label
else:
for w in op.wires:
Expand Down Expand Up @@ -225,7 +226,8 @@ def _add_measurement(m, layer_str, config):
meas_label = m.return_type.value

if len(m.wires) == 0: # state or probability across all wires
for i, s in enumerate(layer_str):
n_wires = len(config.wire_map)
for i, s in enumerate(layer_str[:n_wires]):
layer_str[i] = s + meas_label

for w in m.wires:
Expand Down
47 changes: 47 additions & 0 deletions tests/drawer/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ def circ():
class TestMidCircuitMeasurements:
"""Tests for drawing mid-circuit measurements and classical conditions."""

# pylint: disable=too-many-public-methods

@pytest.mark.parametrize("device_name", ["default.qubit"])
def test_qnode_mid_circuit_measurement_not_deferred(self, device_name, mocker):
"""Test that a circuit containing mid-circuit measurements is transformed by the drawer
Expand Down Expand Up @@ -329,6 +331,51 @@ def func():

assert drawing == expected_drawing

@pytest.mark.parametrize(
"op", [qml.GlobalPhase(0.1), qml.Identity(), qml.Snapshot(), qml.Barrier()]
)
@pytest.mark.parametrize("decimals", [None, 2])
def test_draw_all_wire_ops(self, op, decimals):
"""Test that operators acting on all wires are drawn correctly"""

def func():
qml.X(0)
qml.X(1)
m = qml.measure(0)
qml.cond(m, qml.X)(0)
qml.apply(op)
return qml.expval(qml.Z(0))

# Stripping to remove trailing white-space because length of white-space at the
# end of the drawing depends on the length of each individual line
drawing = qml.draw(func, decimals=decimals)().strip()
label = op.label(decimals=decimals).replace("\n", "")
expected_drawing = (
f"0: ──X──┤↗├──X──{label}─┤ <Z>\n1: ──X───║───║──{label}─┤ \n ╚═══╝"
)

assert drawing == expected_drawing

@pytest.mark.parametrize(
"mp, label", [(qml.sample(), "Sample"), (qml.probs(), "Probs"), (qml.counts(), "Counts")]
)
def test_draw_all_wire_measurements(self, mp, label):
"""Test that operators acting on all wires are drawn correctly"""

def func():
qml.X(0)
qml.X(1)
m = qml.measure(0)
qml.cond(m, qml.X)(0)
return qml.apply(mp)

# Stripping to remove trailing white-space because length of white-space at the
# end of the drawing depends on the length of each individual line
drawing = qml.draw(func)().strip()
expected_drawing = f"0: ──X──┤↗├──X─┤ {label}\n1: ──X───║───║─┤ {label}\n ╚═══╝"

assert drawing == expected_drawing

def test_draw_mid_circuit_measurement_multiple_wires(self):
"""Test that mid-circuit measurements are correctly drawn in circuits
with multiple wires."""
Expand Down

0 comments on commit 3f52824

Please sign in to comment.