Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add QNode config for mid-circuit measurement options #5679

Merged
merged 59 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
4f5ce6a
Added functionality for mcm config
mudit2812 May 10, 2024
3e65b57
Added qjit error
mudit2812 May 16, 2024
bf88883
Update qnode signature
mudit2812 May 21, 2024
7c37074
Minor changes to tidy up
mudit2812 May 21, 2024
b0db920
[skip ci] update changelog
mudit2812 May 21, 2024
a8df5ed
Merge branch 'master' into postselect-choice
mudit2812 May 21, 2024
b20ee65
Update qnode processing for old devices
mudit2812 May 21, 2024
173d03f
[skip ci] Skip CI
mudit2812 May 21, 2024
bb0db18
Reverting debugging changes
mudit2812 May 22, 2024
029ac2a
Fix interface check in qml.execute
mudit2812 May 22, 2024
cd21934
Added preprocess tests
mudit2812 May 22, 2024
fdbaf42
Added defer_measurements tests
mudit2812 May 22, 2024
0a9f52e
Added qnode tests
mudit2812 May 22, 2024
3c641f9
Merge branch 'master' into postselect-choice
mudit2812 May 22, 2024
79d078e
Fixed qnode test
mudit2812 May 23, 2024
4319f73
Update simulate to map wires before simulation
mudit2812 May 23, 2024
ac16f17
Update simulation functions to not map wires and make it a precondition
mudit2812 May 23, 2024
375c35f
Merge branch 'master' into postselect-choice
mudit2812 May 23, 2024
26b1b72
Merge branch 'master' into postselect-choice
mudit2812 May 23, 2024
6e6cb0c
Added more tests for code coverage
mudit2812 May 23, 2024
77614fc
Updated postselect_shots to postselect_mode
mudit2812 May 23, 2024
951254f
Merge branch 'master' into postselect-choice
mudit2812 May 23, 2024
fcdfee6
Added docs
mudit2812 May 24, 2024
83a658e
Merge branch 'master' into postselect-choice
mudit2812 May 24, 2024
60912d2
Fix docs
mudit2812 May 24, 2024
dab94bb
Fix indentation in docs
mudit2812 May 24, 2024
a40e845
Fix indentation again
mudit2812 May 24, 2024
9b14d37
Another indentation fix..
mudit2812 May 24, 2024
6abdd02
Made mcm_config dataclass
mudit2812 May 27, 2024
356f4cf
Merge branch 'master' into postselect-choice
mudit2812 May 27, 2024
73ee6e2
Fix linting after merge issues
mudit2812 May 27, 2024
9218ec1
Fixing MCMConfig intialization
mudit2812 May 27, 2024
8548502
Added fill-shots support to dynamic_one_shot; fixing type hints
mudit2812 May 28, 2024
8e6cb55
Added test for jax with postselect_mode='hw-like' and one-shot
mudit2812 May 28, 2024
3056cb8
Updated docs
mudit2812 May 28, 2024
6069b10
Add info about defaults to docs
mudit2812 May 28, 2024
8d03b8d
Merge branch 'master' into postselect-choice
mudit2812 May 28, 2024
a4f40e5
Removed failing test
mudit2812 May 28, 2024
c6d2a4c
Updated docs
mudit2812 May 29, 2024
91f2be0
Merge branch 'master' into postselect-choice
mudit2812 May 29, 2024
d8d60bb
Update pennylane/devices/execution_config.py
mudit2812 May 29, 2024
70bd23d
Remove qjit checks
mudit2812 May 29, 2024
c19dddd
Fixing linting error
mudit2812 May 29, 2024
9dde1ee
Fix execution config; add docs per code review
mudit2812 May 29, 2024
88d9834
Addressing code review; replacing some warnings with errors
mudit2812 May 30, 2024
b809cd2
Fixed defer_measurements fill-shots; added functionality for jax
mudit2812 May 30, 2024
0db09e3
Merge branch 'master' into postselect-choice
mudit2812 May 31, 2024
85258d4
Update old device API MCM config support
mudit2812 Jun 3, 2024
e0555e5
Default postselect_mode=None, raise error with DM+jax_jit+hw_like, docs
mudit2812 Jun 3, 2024
bea44ba
Found a way to raise error with diff_method=None
mudit2812 Jun 3, 2024
ac78db5
Update test doc
mudit2812 Jun 3, 2024
62e2d4b
Added additional condition to qml.execute
mudit2812 Jun 3, 2024
cbdbd5c
Added dev comment
mudit2812 Jun 3, 2024
851cadf
Merge branch 'master' into postselect-choice
mudit2812 Jun 4, 2024
47399e2
[skip ci] Skip CI
mudit2812 Jun 4, 2024
31857c8
Addressing code review; fixing tests
mudit2812 Jun 4, 2024
ef35642
Fixing code cov
mudit2812 Jun 4, 2024
d7eeb1d
Add coverage; update docs
mudit2812 Jun 4, 2024
2e3e3a0
doc fix
mudit2812 Jun 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions doc/introduction/measurements.rst
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,86 @@ Collecting statistics for sequences of mid-circuit measurements is supported wit
When collecting statistics for a list of mid-circuit measurements, values manipulated using
arithmetic operators should not be used as this behaviour is not supported.

Configuring mid-circuit measurements
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
************************************

As seen above, there are multiple ways in which circuits with mid-circuit measurements can be executed with
PennyLane. For ease of use, we provide the following configuration options to users when initializing a
:class:`~pennylane.QNode`:

* ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"``
to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as
described above. When executing with finite shots, ``mcm_method="one-shot"`` will be the default, and
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
``mcm_method="deferred"`` otherwise.

.. warning::

If the ``mcm_method`` argument is provided, the :func:`~pennylane.defer_measurements` or
:func:`~pennylane.dynamic_one_shot` transforms must not be applied directly to the :class:`~pennylane.QNode`
as it can lead to incorrect behaviour.

* ``postselect_mode``: To configure how invalid shots are handled when postselecting mid-circuit measurements
with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid samples. In this case, the number
of samples that are used for processing results can be less than the total number of shots. If
``postselect_mode="fill-shots"`` is used, then the postselected value will be sampled unconditionally, and all
samples will be valid. This is equivalent to sampling until the number of valid samples matches the total number
of shots. The default behaviour is ``postselect_mode="hw-like"``.
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python3

import pennylane as qml
import numpy as np

dev = qml.device("default.qubit", wires=3, shots=10)

def circuit(x):
qml.RX(x, 0)
m0 = qml.measure(0, postselect=1)
qml.CNOT([0, 1])
return qml.sample(qml.PauliZ(0))

fill_shots_qnode = qml.QNode(circuit, dev, mcm_method="one-shot", postselect_mode="fill-shots")
hw_like_qnode = qml.QNode(circuit, dev, mcm_method="one-shot", postselect_mode="hw-like")

>>> fill_shots_qnode(np.pi / 2)
array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.])
>>> hw_like_qnode(np.pi / 2)
array([-1., -1., -1., -1., -1., -1., -1.])

.. note::

When using the ``jax`` interface, ``postselect_mode="hw-like"`` will have different behaviour based on the
chosen ``mcm_method``.

* If ``mcm_method="one-shot"``, invalid shots will not be discarded. Instead, invalid samples will be replaced
by ``np.iinfo(np.int32).min``. These invalid samples will not be used for processing final results (like
expectation values), but will appear in the ``QNode`` output if samples are requested directly. Consider
the circuit below:

.. code-block:: python3

import pennylane as qml
import jax
import jax.numpy as jnp

dev = qml.device("default.qubit", wires=3, shots=10, seed=jax.random.PRNGKey(123))

@qml.qnode(dev, postselect_mode="hw-like", mcm_method="one-shot")
def circuit(x):
qml.RX(x, 0)
qml.measure(0, postselect=1)
return qml.sample(qml.PauliZ(0))

>>> x = jnp.array(1.8)
>>> f(x)
Array([-2.1474836e+09, -1.0000000e+00, -2.1474836e+09, -2.1474836e+09,
-1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -2.1474836e+09,
-1.0000000e+00, -1.0000000e+00], dtype=float32, weak_type=True)

* If ``mcm_method="deferred"``, then using ``postselect_mode="hw-like"`` will have the same behaviour as when
``postselect_mode="fill-shots"``. This is due to the limitations of the :func:`~pennylane.defer_measurements`
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
transform, and this behaviour will change in the future to be more consistent with ``mcm_method="one-shot"``.

Changing the number of shots
----------------------------

Expand Down
15 changes: 14 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@

<h3>New features since last release</h3>

* `qml.QNode` and `qml.qnode` now accept two new keyword arguments: `postselect_mode` and `mcm_method`.
These keyword arguments can be used to configure how the device should behave when running circuits with
mid-circuit measurements.
[(#5679)](https://github.com/PennyLaneAI/pennylane/pull/5679)

* `postselect_mode="hw-like"` will indicate to devices to discard invalid shots when postselecting
mid-circuit measurements. Use `postselect_mode="fill-shots"` to unconditionally sample the postselected
value, thus making all samples valid. This is equivalent to sampling until the number of valid samples
matches the total number of shots.
* `mcm_method` will indicate which strategy to use for running circuits with mid-circuit measurements.
Use `mcm_method="deferred"` to use the deferred measurements principle, or `mcm_method="one-shot"`
to execute once for each shot.

* The `default.tensor` device is introduced to perform tensor network simulation of a quantum circuit.
[(#5699)](https://github.com/PennyLaneAI/pennylane/pull/5699)

Expand Down Expand Up @@ -31,7 +44,7 @@

* The `dynamic_one_shot` transform can be compiled with `jax.jit`.
[(#5557)](https://github.com/PennyLaneAI/pennylane/pull/5557)

* When using `defer_measurements` with postselecting mid-circuit measurements, operations
that will never be active due to the postselected state are skipped in the transformed
quantum circuit. In addition, postselected controls are skipped, as they are evaluated
Expand Down
3 changes: 2 additions & 1 deletion pennylane/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
:toctree: api

ExecutionConfig
MCMConfig
Device
DefaultQubit
NullQubit
Expand Down Expand Up @@ -146,7 +147,7 @@ def execute(self, circuits, execution_config = qml.devices.DefaultExecutionConfi

"""

from .execution_config import ExecutionConfig, DefaultExecutionConfig
from .execution_config import ExecutionConfig, DefaultExecutionConfig, MCMConfig
from .device_api import Device
from .default_qubit import DefaultQubit

Expand Down
17 changes: 15 additions & 2 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,9 @@ def preprocess(
transform_program = TransformProgram()

transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
transform_program.add_transform(mid_circuit_measurements, device=self)
transform_program.add_transform(
mid_circuit_measurements, device=self, mcm_config=config.mcm_config
)
transform_program.add_transform(
decompose,
stopping_condition=stopping_condition,
Expand Down Expand Up @@ -596,14 +598,22 @@ def execute(
"interface": interface,
"state_cache": self._state_cache,
"prng_key": _key,
"postselect_mode": execution_config.mcm_config.postselect_mode,
},
)
for c, _key in zip(circuits, prng_keys)
)

vanilla_circuits = convert_to_numpy_parameters(circuits)[0]
seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))
simulate_kwargs = [{"rng": _rng, "prng_key": _key} for _rng, _key in zip(seeds, prng_keys)]
simulate_kwargs = [
{
"rng": _rng,
"prng_key": _key,
"postselect_mode": execution_config.mcm_config.postselect_mode,
}
for _rng, _key in zip(seeds, prng_keys)
]

with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
exec_map = executor.map(_simulate_wrapper, vanilla_circuits, simulate_kwargs)
Expand Down Expand Up @@ -847,20 +857,23 @@ def _simulate_wrapper(circuit, kwargs):


def _adjoint_jac_wrapper(c, debugger=None):
c = c.map_to_standard_wires()
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
state, is_state_batched = get_final_state(c, debugger=debugger)
jac = adjoint_jacobian(c, state=state)
res = measure_final_state(c, state, is_state_batched)
return res, jac


def _adjoint_jvp_wrapper(c, t, debugger=None):
c = c.map_to_standard_wires()
state, is_state_batched = get_final_state(c, debugger=debugger)
jvp = adjoint_jvp(c, t, state=state)
res = measure_final_state(c, state, is_state_batched)
return res, jvp


def _adjoint_vjp_wrapper(c, t, debugger=None):
c = c.map_to_standard_wires()
state, is_state_batched = get_final_state(c, debugger=debugger)
vjp = adjoint_vjp(c, t, state=state)
res = measure_final_state(c, state, is_state_batched)
Expand Down
27 changes: 26 additions & 1 deletion pennylane/devices/execution_config.py
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,28 @@
Contains the :class:`ExecutionConfig` data class.
"""
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

from pennylane.workflow import SUPPORTED_INTERFACES


@dataclass
class MCMConfig:
"""A class to store mid-circuit measurement configurations."""

mcm_method: Optional[str] = None
"""Which mid-circuit measurement strategy to use. Use ``deferred`` for the deferred
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
measurements principle and "one-shot" if using finite shots to execute the circuit
for each shot separately. If not specified, the device will decide which method to
use."""

postselect_mode: str = "hw-like"
"""Configuration for handling shots with mid-circuit measurement postselection. If
``"hw-like"``, invalid shots will be discarded and only results for valid shots will
be returned. If ``"fill-shots"``, results corresponding to the original number of
shots will be returned."""

mudit2812 marked this conversation as resolved.
Show resolved Hide resolved

# pylint: disable=too-many-instance-attributes
@dataclass
class ExecutionConfig:
Expand Down Expand Up @@ -67,6 +84,9 @@ class ExecutionConfig:
derivative_order: int = 1
"""The derivative order to compute while evaluating a gradient"""

mcm_config: Union[MCMConfig, dict] = MCMConfig()
"""Configuration options for handling mid-circuit measurements"""

def __post_init__(self):
"""
Validate the configured execution options.
Expand All @@ -89,5 +109,10 @@ def __post_init__(self):
if self.gradient_keyword_arguments is None:
self.gradient_keyword_arguments = {}

if isinstance(self.mcm_config, dict):
self.mcm_config = MCMConfig(**self.mcm_config)
elif not isinstance(self.mcm_config, MCMConfig):
raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'")


DefaultExecutionConfig = ExecutionConfig()
28 changes: 18 additions & 10 deletions pennylane/devices/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from pennylane.typing import Result, ResultBatch
from pennylane.wires import WireError

from .execution_config import MCMConfig

PostprocessingFn = Callable[[ResultBatch], Union[Result, ResultBatch]]


Expand Down Expand Up @@ -80,7 +82,7 @@ def _operator_decomposition_gen(
@transform
def no_sampling(
tape: qml.tape.QuantumTape, name: str = "device"
) -> (Sequence[qml.tape.QuantumTape], Callable):
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
"""Raises an error if the tape has finite shots.

Args:
Expand All @@ -104,7 +106,7 @@ def no_sampling(
@transform
def validate_device_wires(
tape: qml.tape.QuantumTape, wires: Optional[qml.wires.Wires] = None, name: str = "device"
) -> (Sequence[qml.tape.QuantumTape], Callable):
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
"""Validates that all wires present in the tape are in the set of provided wires. Adds the
device wires to measurement processes like :class:`~.measurements.StateMP` that are broadcasted
across all available wires.
Expand Down Expand Up @@ -145,23 +147,29 @@ def validate_device_wires(

@transform
def mid_circuit_measurements(
tape: qml.tape.QuantumTape, device
) -> (Sequence[qml.tape.QuantumTape], Callable):
tape: qml.tape.QuantumTape, device, mcm_config=MCMConfig()
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
"""Provide the transform to handle mid-circuit measurements.

If the tape or device uses finite-shot, use the native implementation (i.e. no transform),
and use the ``qml.defer_measurements`` transform otherwise.
"""

if tape.shots:
if isinstance(mcm_config, dict):
mcm_config = MCMConfig(**mcm_config)
mcm_method = mcm_config.mcm_method
if mcm_method is None:
mcm_method = "one-shot" if tape.shots else "deferred"

if mcm_method == "one-shot":
return qml.dynamic_one_shot(tape)
return qml.defer_measurements(tape, device=device)


@transform
def validate_multiprocessing_workers(
tape: qml.tape.QuantumTape, max_workers: int, device
) -> (Sequence[qml.tape.QuantumTape], Callable):
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
"""Validates the number of workers for multiprocessing.

Checks that the CPU is not oversubscribed and warns user if it is,
Expand Down Expand Up @@ -220,7 +228,7 @@ def validate_multiprocessing_workers(
@transform
def validate_adjoint_trainable_params(
tape: qml.tape.QuantumTape,
) -> (Sequence[qml.tape.QuantumTape], Callable):
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
"""Raises a warning if any of the observables is trainable, and raises an error if any
trainable parameters belong to state-prep operations. Can be used in validating circuits
for adjoint differentiation.
Expand Down Expand Up @@ -256,7 +264,7 @@ def decompose(
max_expansion: Union[int, None] = None,
name: str = "device",
error: Exception = DeviceError,
) -> (Sequence[qml.tape.QuantumTape], Callable):
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
"""Decompose operations until the stopping condition is met.

Args:
Expand Down Expand Up @@ -370,7 +378,7 @@ def validate_observables(
tape: qml.tape.QuantumTape,
stopping_condition: Callable[[qml.operation.Operator], bool],
name: str = "device",
) -> (Sequence[qml.tape.QuantumTape], Callable):
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
"""Validates the observables and measurements for a circuit.

Args:
Expand Down Expand Up @@ -412,7 +420,7 @@ def validate_observables(
@transform
def validate_measurements(
tape: qml.tape.QuantumTape, analytic_measurements=None, sample_measurements=None, name="device"
) -> (Sequence[qml.tape.QuantumTape], Callable):
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
"""Validates the supported state and sample based measurement processes.

Args:
Expand Down
Loading
Loading