Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support dynamic measurements with debugger #5749

Merged
merged 39 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
20eef13
implement breakpoint()
Jaybsoni Apr 30, 2024
c814e7e
added docstring and tests
Jaybsoni Apr 30, 2024
f120900
Merge branch 'master' into algo_debug_breakpoint
Jaybsoni Apr 30, 2024
508bb8b
format
Jaybsoni Apr 30, 2024
3d7de97
Merge branch 'master' into algo_debug_breakpoint
Jaybsoni May 7, 2024
18bffe2
Add missing test coverage
Jaybsoni May 10, 2024
f82ea4d
Merge branch 'master' into algo_debug_breakpoint
Jaybsoni May 10, 2024
86829ee
Merge branch 'master' into algo_debug_breakpoint
Jaybsoni May 21, 2024
ad62173
Merge branch 'master' into algo_debug_breakpoint
Jaybsoni May 23, 2024
f45dc3e
Merge branch 'master' into algo_debug_breakpoint
Jaybsoni May 23, 2024
26f28f9
Merge branch 'master' into algo_debug_breakpoint
Jaybsoni May 27, 2024
ac476fa
Added support for measurements
Jaybsoni May 28, 2024
5b1ad2c
Added tests
Jaybsoni May 28, 2024
f7feadb
Address code review comments
Jaybsoni May 30, 2024
2d7bfbf
Merge branch 'algo_debug_breakpoint' into algo_debug_measurements
Jaybsoni May 30, 2024
2e00723
add context manager + address code review comments
Jaybsoni May 30, 2024
9b9c99b
merge
Jaybsoni May 30, 2024
a897a47
[skip ci]
Jaybsoni May 30, 2024
8cd8255
merge [skip ci]
Jaybsoni May 30, 2024
6751e87
fix typo
Jaybsoni May 31, 2024
f537c39
pull class method back into class
Jaybsoni May 31, 2024
74a2292
Add probs support
Jaybsoni Jun 3, 2024
9ef7b88
[skip ci]
Jaybsoni Jun 3, 2024
bdf7ba8
[skip ci]
Jaybsoni Jun 3, 2024
eb8e852
fix output, [skip ci]
Jaybsoni Jun 3, 2024
3f9585d
expand state to device wires
Jaybsoni Jun 4, 2024
13f55f0
Apply suggestions from code review
Jaybsoni Jun 5, 2024
18b75df
[skip ci]
Jaybsoni Jun 5, 2024
7b31827
Addressed code review comments
Jaybsoni Jun 5, 2024
84890ff
format tests
Jaybsoni Jun 5, 2024
1d03db1
fix bugs
Jaybsoni Jun 5, 2024
1ad2d3c
merge base branch, [skip ci]
Jaybsoni Jun 5, 2024
7ca926e
missing test, [skip ci]
Jaybsoni Jun 5, 2024
b37c117
Merge branch 'debugging_feature' into algo_debug_measurements
Jaybsoni Jun 6, 2024
246a12b
format, [skip ci]
Jaybsoni Jun 6, 2024
1803c0b
address code review comments
Jaybsoni Jun 6, 2024
d554e23
Apply suggestions from code review
Jaybsoni Jun 7, 2024
ebfbc7f
address code review comments
Jaybsoni Jun 7, 2024
aa7d354
lint + format
Jaybsoni Jun 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions pennylane/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
This module contains functionality for debugging quantum programs on simulator devices.
"""
import copy
import pdb
import sys
from contextlib import contextmanager
Expand Down Expand Up @@ -186,6 +187,23 @@ def reset_active_dev(cls):
"""Reset the global active device variable to None."""
cls.__active_dev = None

@classmethod
def _execute(cls, batch_tapes):
"""Execute tape on the active device"""
dev = cls.get_active_device()

valid_batch = batch_tapes
if dev.wires:
valid_batch = qml.devices.preprocess.validate_device_wires(
batch_tapes, wires=dev.wires
)[0]

program, new_config = dev.preprocess()
new_batch, fn = program(valid_batch)

# TODO: remove [0] index once compatible with transforms
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
return fn(dev.execute(new_batch, new_config))[0]


@contextmanager
def pldb_device_manager(device):
Expand All @@ -208,3 +226,82 @@ def breakpoint():

debugger = PLDB(skip=["pennylane.*"]) # skip internals when stepping through trace
debugger.set_trace(sys._getframe().f_back) # pylint: disable=protected-access


def state():
"""Compute the state of the quantum circuit.

Returns:
Array(complex): quantum state of the circuit.
"""
with qml.queuing.QueuingManager.stop_recording():
m = qml.state()

return _measure(m)


def expval(op):
"""Compute the expectation value of an observable.

Args:
op (Operator): the observable to compute the expectation value for

Returns:
complex: expectation value of the operator
"""

qml.queuing.QueuingManager.active_context().remove(op) # ensure we didn't accidentally queue op
soranjh marked this conversation as resolved.
Show resolved Hide resolved

with qml.queuing.QueuingManager.stop_recording():
m = qml.expval(op)

return _measure(m)


def probs(wires=None, op=None):
"""Compute the probability distribution for the state.
Args:
wires (Union[Iterable, int, str, list]): the wires the operation acts on
op (Union[Observable, MeasurementValue]): observable (with a ``diagonalizing_gates``
attribute) that rotates the computational basis, or a ``MeasurementValue``
corresponding to mid-circuit measurements.

Returns:
Array(float): the probability distribution of the bitstrings for the wires
"""
if op:
qml.queuing.QueuingManager.active_context().remove(
op
) # ensure we didn't accidentally queue op

with qml.queuing.QueuingManager.stop_recording():
m = qml.probs(wires, op)

return _measure(m)


def _measure(measurement):
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
"""Perform the measurement.

Args:
measurement (MeasurementProcess): the type of measurement to be performed

Returns:
tuple(complex): results from the measurement
"""
active_queue = qml.queuing.QueuingManager.active_context()
copied_queue = copy.deepcopy(active_queue)

copied_queue.append(measurement)
qtape = qml.tape.QuantumScript.from_queue(copied_queue)
return PLDB._execute((qtape,)) # pylint: disable=protected-access


def tape():
soranjh marked this conversation as resolved.
Show resolved Hide resolved
"""Access the quantum tape of the circuit.

Returns:
QuantumScript: the quantum tape representing the circuit
"""
active_queue = qml.queuing.QueuingManager.active_context()
return qml.tape.QuantumScript.from_queue(active_queue)
144 changes: 143 additions & 1 deletion tests/test_debugging.py
soranjh marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest

import pennylane as qml
from pennylane import numpy as qnp
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
from pennylane.debugging import PLDB, pldb_device_manager


Expand Down Expand Up @@ -427,7 +428,7 @@ def my_qfunc():

def test_valid_context_not_compatible_device(self):
"""Test that valid_context raises an error when breakpoint
is called with an un-supported device."""
is called with an incompatible device."""
dev = qml.device("default.mixed", wires=2)

@qml.qnode(dev)
Expand Down Expand Up @@ -503,6 +504,147 @@ def test_has_active_device(self):
PLDB.reset_active_dev()
assert not PLDB.has_active_dev()

tapes = (
qml.tape.QuantumScript(
ops=[qml.Hadamard(0), qml.CNOT([0, 1])],
measurements=[qml.state()],
),
qml.tape.QuantumScript(
ops=[qml.Hadamard(0), qml.X(1)],
measurements=[qml.expval(qml.Z(1))],
),
qml.tape.QuantumScript(
ops=[qml.Hadamard(0), qml.CNOT([0, 1])],
measurements=[qml.probs()],
),
qml.tape.QuantumScript(
ops=[qml.Hadamard(0), qml.CNOT([0, 1])],
measurements=[qml.probs(wires=[0])],
),
qml.tape.QuantumScript(
ops=[qml.Hadamard(0)],
measurements=[qml.state()],
), # Test that state expands to number of device wires
soranjh marked this conversation as resolved.
Show resolved Hide resolved
)

results = (
qnp.array([1 / qnp.sqrt(2), 0, 0, 1 / qnp.sqrt(2)], dtype=complex),
qnp.array(-1),
qnp.array([1 / 2, 0, 0, 1 / 2]),
qnp.array([1 / 2, 1 / 2]),
qnp.array([1 / qnp.sqrt(2), 0, 1 / qnp.sqrt(2), 0], dtype=complex),
)

@pytest.mark.parametrize("tape, expected_result", zip(tapes, results))
@pytest.mark.parametrize(
"dev", (qml.device("default.qubit", wires=2), qml.device("lightning.qubit", wires=2))
)
def test_execute(self, dev, tape, expected_result):
"""Test that the _execute method works as expected."""
PLDB.add_device(dev)
executed_results = PLDB._execute((tape,))
assert qnp.allclose(expected_result, executed_results)
PLDB.reset_active_dev()


def test_tape():
"""Test that we can access the tape from the active queue."""
with qml.queuing.AnnotatedQueue() as queue:
qml.X(0)

for i in range(3):
qml.Hadamard(i)

qml.Y(1)
qml.Z(0)
qml.expval(qml.Z(0))

executed_tape = qml.debugging.tape()

expected_tape = qml.tape.QuantumScript.from_queue(queue)
assert qml.equal(expected_tape, executed_tape)


@pytest.mark.parametrize("measurement_process", (qml.expval(qml.Z(0)), qml.state(), qml.probs()))
@patch.object(PLDB, "_execute")
def test_measure(mock_method, measurement_process):
"""Test that the private measure function doesn't modify the active queue"""
with qml.queuing.AnnotatedQueue() as queue:
ops = [qml.X(0), qml.Y(1), qml.Z(0)] + [qml.Hadamard(i) for i in range(3)]
measurements = [qml.expval(qml.X(2)), qml.state(), qml.probs(), qml.var(qml.Z(3))]
qml.debugging._measure(measurement_process)

executed_tape = qml.tape.QuantumScript.from_queue(queue)
expected_tape = qml.tape.QuantumScript(ops, measurements)

assert qml.equal(expected_tape, executed_tape) # no unexpected queuing

expected_debugging_tape = qml.tape.QuantumScript(ops, measurements + [measurement_process])
executed_debugging_tape = mock_method.call_args.args[0][0]

assert qml.equal(
expected_debugging_tape, executed_debugging_tape
) # _execute was called with new measurements


@patch.object(PLDB, "_execute")
def test_state(_mock_method):
"""Test that the state function works as expected."""
with qml.queuing.AnnotatedQueue() as queue:
qml.RX(1.23, 0)
qml.RY(0.45, 2)
qml.sample()

qml.debugging.state()

assert qml.state() not in queue


@patch.object(PLDB, "_execute")
def test_expval(_mock_method):
"""Test that the expval function works as expected."""
for op in [qml.X(0), qml.Y(1), qml.Z(2), qml.Hadamard(0)]:
with qml.queuing.AnnotatedQueue() as queue:
qml.RX(1.23, 0)
qml.RY(0.45, 2)
qml.sample()

qml.debugging.expval(op)

assert op not in queue
assert qml.expval(op) not in queue


@patch.object(PLDB, "_execute")
def test_probs_with_op(_mock_method):
"""Test that the probs function works as expected."""

for op in [None, qml.X(0), qml.Y(1), qml.Z(2)]:
with qml.queuing.AnnotatedQueue() as queue:
qml.RX(1.23, 0)
qml.RY(0.45, 2)
qml.sample()

qml.debugging.probs(op=op)

assert op not in queue
assert qml.probs(op=op) not in queue


@patch.object(PLDB, "_execute")
def test_probs_with_wires(_mock_method):
"""Test that the probs function works as expected."""

for wires in [None, [0, 1], [2]]:
with qml.queuing.AnnotatedQueue() as queue:
qml.RX(1.23, 0)
qml.RY(0.45, 2)
qml.sample()

qml.debugging.probs(wires=wires)

assert qml.probs(wires=wires) not in queue


@pytest.mark.parametrize("device_name", ("default.qubit", "lightning.qubit"))
def test_pldb_device_manager(device_name):
Expand Down
Loading