diff --git a/pennylane/debugging.py b/pennylane/debugging.py index 9db341e5eee..21ff5be40f3 100644 --- a/pennylane/debugging.py +++ b/pennylane/debugging.py @@ -14,6 +14,7 @@ """ This module contains functionality for debugging quantum programs on simulator devices. """ +import copy import pdb import sys from contextlib import contextmanager @@ -186,6 +187,23 @@ def reset_active_dev(cls): """Reset the global active device variable to None.""" cls.__active_dev = None + @classmethod + def _execute(cls, batch_tapes): + """Execute tape on the active device""" + dev = cls.get_active_device() + + valid_batch = batch_tapes + if dev.wires: + valid_batch = qml.devices.preprocess.validate_device_wires( + batch_tapes, wires=dev.wires + )[0] + + program, new_config = dev.preprocess() + new_batch, fn = program(valid_batch) + + # TODO: remove [0] index once compatible with transforms + return fn(dev.execute(new_batch, new_config))[0] + @contextmanager def pldb_device_manager(device): @@ -208,3 +226,82 @@ def breakpoint(): debugger = PLDB(skip=["pennylane.*"]) # skip internals when stepping through trace debugger.set_trace(sys._getframe().f_back) # pylint: disable=protected-access + + +def state(): + """Compute the state of the quantum circuit. + + Returns: + Array(complex): quantum state of the circuit. + """ + with qml.queuing.QueuingManager.stop_recording(): + m = qml.state() + + return _measure(m) + + +def expval(op): + """Compute the expectation value of an observable. + + Args: + op (Operator): the observable to compute the expectation value for + + Returns: + complex: expectation value of the operator + """ + + qml.queuing.QueuingManager.active_context().remove(op) # ensure we didn't accidentally queue op + + with qml.queuing.QueuingManager.stop_recording(): + m = qml.expval(op) + + return _measure(m) + + +def probs(wires=None, op=None): + """Compute the probability distribution for the state. + Args: + wires (Union[Iterable, int, str, list]): the wires the operation acts on + op (Union[Observable, MeasurementValue]): observable (with a ``diagonalizing_gates`` + attribute) that rotates the computational basis, or a ``MeasurementValue`` + corresponding to mid-circuit measurements. + + Returns: + Array(float): the probability distribution of the bitstrings for the wires + """ + if op: + qml.queuing.QueuingManager.active_context().remove( + op + ) # ensure we didn't accidentally queue op + + with qml.queuing.QueuingManager.stop_recording(): + m = qml.probs(wires, op) + + return _measure(m) + + +def _measure(measurement): + """Perform the measurement. + + Args: + measurement (MeasurementProcess): the type of measurement to be performed + + Returns: + tuple(complex): results from the measurement + """ + active_queue = qml.queuing.QueuingManager.active_context() + copied_queue = copy.deepcopy(active_queue) + + copied_queue.append(measurement) + qtape = qml.tape.QuantumScript.from_queue(copied_queue) + return PLDB._execute((qtape,)) # pylint: disable=protected-access + + +def tape(): + """Access the quantum tape of the circuit. + + Returns: + QuantumScript: the quantum tape representing the circuit + """ + active_queue = qml.queuing.QueuingManager.active_context() + return qml.tape.QuantumScript.from_queue(active_queue) diff --git a/tests/test_debugging.py b/tests/test_debugging.py index 0fd21099202..30d49f6d387 100644 --- a/tests/test_debugging.py +++ b/tests/test_debugging.py @@ -20,6 +20,7 @@ import pytest import pennylane as qml +from pennylane import numpy as qnp from pennylane.debugging import PLDB, pldb_device_manager @@ -427,7 +428,7 @@ def my_qfunc(): def test_valid_context_not_compatible_device(self): """Test that valid_context raises an error when breakpoint - is called with an un-supported device.""" + is called with an incompatible device.""" dev = qml.device("default.mixed", wires=2) @qml.qnode(dev) @@ -503,6 +504,147 @@ def test_has_active_device(self): PLDB.reset_active_dev() assert not PLDB.has_active_dev() + tapes = ( + qml.tape.QuantumScript( + ops=[qml.Hadamard(0), qml.CNOT([0, 1])], + measurements=[qml.state()], + ), + qml.tape.QuantumScript( + ops=[qml.Hadamard(0), qml.X(1)], + measurements=[qml.expval(qml.Z(1))], + ), + qml.tape.QuantumScript( + ops=[qml.Hadamard(0), qml.CNOT([0, 1])], + measurements=[qml.probs()], + ), + qml.tape.QuantumScript( + ops=[qml.Hadamard(0), qml.CNOT([0, 1])], + measurements=[qml.probs(wires=[0])], + ), + qml.tape.QuantumScript( + ops=[qml.Hadamard(0)], + measurements=[qml.state()], + ), # Test that state expands to number of device wires + ) + + results = ( + qnp.array([1 / qnp.sqrt(2), 0, 0, 1 / qnp.sqrt(2)], dtype=complex), + qnp.array(-1), + qnp.array([1 / 2, 0, 0, 1 / 2]), + qnp.array([1 / 2, 1 / 2]), + qnp.array([1 / qnp.sqrt(2), 0, 1 / qnp.sqrt(2), 0], dtype=complex), + ) + + @pytest.mark.parametrize("tape, expected_result", zip(tapes, results)) + @pytest.mark.parametrize( + "dev", (qml.device("default.qubit", wires=2), qml.device("lightning.qubit", wires=2)) + ) + def test_execute(self, dev, tape, expected_result): + """Test that the _execute method works as expected.""" + PLDB.add_device(dev) + executed_results = PLDB._execute((tape,)) + assert qnp.allclose(expected_result, executed_results) + PLDB.reset_active_dev() + + +def test_tape(): + """Test that we can access the tape from the active queue.""" + with qml.queuing.AnnotatedQueue() as queue: + qml.X(0) + + for i in range(3): + qml.Hadamard(i) + + qml.Y(1) + qml.Z(0) + qml.expval(qml.Z(0)) + + executed_tape = qml.debugging.tape() + + expected_tape = qml.tape.QuantumScript.from_queue(queue) + assert qml.equal(expected_tape, executed_tape) + + +@pytest.mark.parametrize("measurement_process", (qml.expval(qml.Z(0)), qml.state(), qml.probs())) +@patch.object(PLDB, "_execute") +def test_measure(mock_method, measurement_process): + """Test that the private measure function doesn't modify the active queue""" + with qml.queuing.AnnotatedQueue() as queue: + ops = [qml.X(0), qml.Y(1), qml.Z(0)] + [qml.Hadamard(i) for i in range(3)] + measurements = [qml.expval(qml.X(2)), qml.state(), qml.probs(), qml.var(qml.Z(3))] + qml.debugging._measure(measurement_process) + + executed_tape = qml.tape.QuantumScript.from_queue(queue) + expected_tape = qml.tape.QuantumScript(ops, measurements) + + assert qml.equal(expected_tape, executed_tape) # no unexpected queuing + + expected_debugging_tape = qml.tape.QuantumScript(ops, measurements + [measurement_process]) + executed_debugging_tape = mock_method.call_args.args[0][0] + + assert qml.equal( + expected_debugging_tape, executed_debugging_tape + ) # _execute was called with new measurements + + +@patch.object(PLDB, "_execute") +def test_state(_mock_method): + """Test that the state function works as expected.""" + with qml.queuing.AnnotatedQueue() as queue: + qml.RX(1.23, 0) + qml.RY(0.45, 2) + qml.sample() + + qml.debugging.state() + + assert qml.state() not in queue + + +@patch.object(PLDB, "_execute") +def test_expval(_mock_method): + """Test that the expval function works as expected.""" + for op in [qml.X(0), qml.Y(1), qml.Z(2), qml.Hadamard(0)]: + with qml.queuing.AnnotatedQueue() as queue: + qml.RX(1.23, 0) + qml.RY(0.45, 2) + qml.sample() + + qml.debugging.expval(op) + + assert op not in queue + assert qml.expval(op) not in queue + + +@patch.object(PLDB, "_execute") +def test_probs_with_op(_mock_method): + """Test that the probs function works as expected.""" + + for op in [None, qml.X(0), qml.Y(1), qml.Z(2)]: + with qml.queuing.AnnotatedQueue() as queue: + qml.RX(1.23, 0) + qml.RY(0.45, 2) + qml.sample() + + qml.debugging.probs(op=op) + + assert op not in queue + assert qml.probs(op=op) not in queue + + +@patch.object(PLDB, "_execute") +def test_probs_with_wires(_mock_method): + """Test that the probs function works as expected.""" + + for wires in [None, [0, 1], [2]]: + with qml.queuing.AnnotatedQueue() as queue: + qml.RX(1.23, 0) + qml.RY(0.45, 2) + qml.sample() + + qml.debugging.probs(wires=wires) + + assert qml.probs(wires=wires) not in queue + @pytest.mark.parametrize("device_name", ("default.qubit", "lightning.qubit")) def test_pldb_device_manager(device_name):