Skip to content

Commit

Permalink
Support dynamic measurements with debugger (#5749)
Browse files Browse the repository at this point in the history
**Context:**
Follow up PR adding support for measurements

**Description of the Change:**

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

---------

Co-authored-by: Mikhail Andrenkov <mikhail@xanadu.ai>
Co-authored-by: Utkarsh <utkarshazad98@gmail.com>
  • Loading branch information
3 people authored Jun 7, 2024
1 parent 1a652a7 commit cf7e012
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 1 deletion.
97 changes: 97 additions & 0 deletions pennylane/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
This module contains functionality for debugging quantum programs on simulator devices.
"""
import copy
import pdb
import sys
from contextlib import contextmanager
Expand Down Expand Up @@ -186,6 +187,23 @@ def reset_active_dev(cls):
"""Reset the global active device variable to None."""
cls.__active_dev = None

@classmethod
def _execute(cls, batch_tapes):
"""Execute tape on the active device"""
dev = cls.get_active_device()

valid_batch = batch_tapes
if dev.wires:
valid_batch = qml.devices.preprocess.validate_device_wires(
batch_tapes, wires=dev.wires
)[0]

program, new_config = dev.preprocess()
new_batch, fn = program(valid_batch)

# TODO: remove [0] index once compatible with transforms
return fn(dev.execute(new_batch, new_config))[0]


@contextmanager
def pldb_device_manager(device):
Expand All @@ -208,3 +226,82 @@ def breakpoint():

debugger = PLDB(skip=["pennylane.*"]) # skip internals when stepping through trace
debugger.set_trace(sys._getframe().f_back) # pylint: disable=protected-access


def state():
"""Compute the state of the quantum circuit.
Returns:
Array(complex): quantum state of the circuit.
"""
with qml.queuing.QueuingManager.stop_recording():
m = qml.state()

return _measure(m)


def expval(op):
"""Compute the expectation value of an observable.
Args:
op (Operator): the observable to compute the expectation value for
Returns:
complex: expectation value of the operator
"""

qml.queuing.QueuingManager.active_context().remove(op) # ensure we didn't accidentally queue op

with qml.queuing.QueuingManager.stop_recording():
m = qml.expval(op)

return _measure(m)


def probs(wires=None, op=None):
"""Compute the probability distribution for the state.
Args:
wires (Union[Iterable, int, str, list]): the wires the operation acts on
op (Union[Observable, MeasurementValue]): observable (with a ``diagonalizing_gates``
attribute) that rotates the computational basis, or a ``MeasurementValue``
corresponding to mid-circuit measurements.
Returns:
Array(float): the probability distribution of the bitstrings for the wires
"""
if op:
qml.queuing.QueuingManager.active_context().remove(
op
) # ensure we didn't accidentally queue op

with qml.queuing.QueuingManager.stop_recording():
m = qml.probs(wires, op)

return _measure(m)


def _measure(measurement):
"""Perform the measurement.
Args:
measurement (MeasurementProcess): the type of measurement to be performed
Returns:
tuple(complex): results from the measurement
"""
active_queue = qml.queuing.QueuingManager.active_context()
copied_queue = copy.deepcopy(active_queue)

copied_queue.append(measurement)
qtape = qml.tape.QuantumScript.from_queue(copied_queue)
return PLDB._execute((qtape,)) # pylint: disable=protected-access


def tape():
"""Access the quantum tape of the circuit.
Returns:
QuantumScript: the quantum tape representing the circuit
"""
active_queue = qml.queuing.QueuingManager.active_context()
return qml.tape.QuantumScript.from_queue(active_queue)
144 changes: 143 additions & 1 deletion tests/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest

import pennylane as qml
from pennylane import numpy as qnp
from pennylane.debugging import PLDB, pldb_device_manager


Expand Down Expand Up @@ -427,7 +428,7 @@ def my_qfunc():

def test_valid_context_not_compatible_device(self):
"""Test that valid_context raises an error when breakpoint
is called with an un-supported device."""
is called with an incompatible device."""
dev = qml.device("default.mixed", wires=2)

@qml.qnode(dev)
Expand Down Expand Up @@ -503,6 +504,147 @@ def test_has_active_device(self):
PLDB.reset_active_dev()
assert not PLDB.has_active_dev()

tapes = (
qml.tape.QuantumScript(
ops=[qml.Hadamard(0), qml.CNOT([0, 1])],
measurements=[qml.state()],
),
qml.tape.QuantumScript(
ops=[qml.Hadamard(0), qml.X(1)],
measurements=[qml.expval(qml.Z(1))],
),
qml.tape.QuantumScript(
ops=[qml.Hadamard(0), qml.CNOT([0, 1])],
measurements=[qml.probs()],
),
qml.tape.QuantumScript(
ops=[qml.Hadamard(0), qml.CNOT([0, 1])],
measurements=[qml.probs(wires=[0])],
),
qml.tape.QuantumScript(
ops=[qml.Hadamard(0)],
measurements=[qml.state()],
), # Test that state expands to number of device wires
)

results = (
qnp.array([1 / qnp.sqrt(2), 0, 0, 1 / qnp.sqrt(2)], dtype=complex),
qnp.array(-1),
qnp.array([1 / 2, 0, 0, 1 / 2]),
qnp.array([1 / 2, 1 / 2]),
qnp.array([1 / qnp.sqrt(2), 0, 1 / qnp.sqrt(2), 0], dtype=complex),
)

@pytest.mark.parametrize("tape, expected_result", zip(tapes, results))
@pytest.mark.parametrize(
"dev", (qml.device("default.qubit", wires=2), qml.device("lightning.qubit", wires=2))
)
def test_execute(self, dev, tape, expected_result):
"""Test that the _execute method works as expected."""
PLDB.add_device(dev)
executed_results = PLDB._execute((tape,))
assert qnp.allclose(expected_result, executed_results)
PLDB.reset_active_dev()


def test_tape():
"""Test that we can access the tape from the active queue."""
with qml.queuing.AnnotatedQueue() as queue:
qml.X(0)

for i in range(3):
qml.Hadamard(i)

qml.Y(1)
qml.Z(0)
qml.expval(qml.Z(0))

executed_tape = qml.debugging.tape()

expected_tape = qml.tape.QuantumScript.from_queue(queue)
assert qml.equal(expected_tape, executed_tape)


@pytest.mark.parametrize("measurement_process", (qml.expval(qml.Z(0)), qml.state(), qml.probs()))
@patch.object(PLDB, "_execute")
def test_measure(mock_method, measurement_process):
"""Test that the private measure function doesn't modify the active queue"""
with qml.queuing.AnnotatedQueue() as queue:
ops = [qml.X(0), qml.Y(1), qml.Z(0)] + [qml.Hadamard(i) for i in range(3)]
measurements = [qml.expval(qml.X(2)), qml.state(), qml.probs(), qml.var(qml.Z(3))]
qml.debugging._measure(measurement_process)

executed_tape = qml.tape.QuantumScript.from_queue(queue)
expected_tape = qml.tape.QuantumScript(ops, measurements)

assert qml.equal(expected_tape, executed_tape) # no unexpected queuing

expected_debugging_tape = qml.tape.QuantumScript(ops, measurements + [measurement_process])
executed_debugging_tape = mock_method.call_args.args[0][0]

assert qml.equal(
expected_debugging_tape, executed_debugging_tape
) # _execute was called with new measurements


@patch.object(PLDB, "_execute")
def test_state(_mock_method):
"""Test that the state function works as expected."""
with qml.queuing.AnnotatedQueue() as queue:
qml.RX(1.23, 0)
qml.RY(0.45, 2)
qml.sample()

qml.debugging.state()

assert qml.state() not in queue


@patch.object(PLDB, "_execute")
def test_expval(_mock_method):
"""Test that the expval function works as expected."""
for op in [qml.X(0), qml.Y(1), qml.Z(2), qml.Hadamard(0)]:
with qml.queuing.AnnotatedQueue() as queue:
qml.RX(1.23, 0)
qml.RY(0.45, 2)
qml.sample()

qml.debugging.expval(op)

assert op not in queue
assert qml.expval(op) not in queue


@patch.object(PLDB, "_execute")
def test_probs_with_op(_mock_method):
"""Test that the probs function works as expected."""

for op in [None, qml.X(0), qml.Y(1), qml.Z(2)]:
with qml.queuing.AnnotatedQueue() as queue:
qml.RX(1.23, 0)
qml.RY(0.45, 2)
qml.sample()

qml.debugging.probs(op=op)

assert op not in queue
assert qml.probs(op=op) not in queue


@patch.object(PLDB, "_execute")
def test_probs_with_wires(_mock_method):
"""Test that the probs function works as expected."""

for wires in [None, [0, 1], [2]]:
with qml.queuing.AnnotatedQueue() as queue:
qml.RX(1.23, 0)
qml.RY(0.45, 2)
qml.sample()

qml.debugging.probs(wires=wires)

assert qml.probs(wires=wires) not in queue


@pytest.mark.parametrize("device_name", ("default.qubit", "lightning.qubit"))
def test_pldb_device_manager(device_name):
Expand Down

0 comments on commit cf7e012

Please sign in to comment.