diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 1dd17edd184..e837d5603f7 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -263,6 +263,8 @@ outcome of such mid-circuit measurements: qml.cond(m_0, qml.RY)(y, wires=0) return qml.probs(wires=[0]) +.. _deferred_measurements: + Deferred measurements ********************* @@ -312,6 +314,8 @@ tensor([0.90165331, 0.09834669], requires_grad=True) and quantum hardware. The one-shot transform below does not have this limitation, but has computational cost that scales with the number of shots used. +.. _one_shot_transform: + The one-shot transform ********************** @@ -524,6 +528,87 @@ Collecting statistics for sequences of mid-circuit measurements is supported wit When collecting statistics for a list of mid-circuit measurements, values manipulated using arithmetic operators should not be used as this behaviour is not supported. +Configuring mid-circuit measurements +************************************ + +As seen above, there are multiple ways in which circuits with mid-circuit measurements can be executed with +PennyLane. For ease of use, we provide the following configuration options to users when initializing a +:class:`~pennylane.QNode`: + +* ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"`` + to use the :ref:`deferred measurements principle ` or ``mcm_method="one-shot"`` to use + the :ref:`one-shot transform `. When executing with finite shots, ``mcm_method="one-shot"`` + will be the default, and ``mcm_method="deferred"`` otherwise. + + .. warning:: + + If the ``mcm_method`` argument is provided, the :func:`~pennylane.defer_measurements` or + :func:`~pennylane.dynamic_one_shot` transforms must not be applied directly to the :class:`~pennylane.QNode` + as it can lead to incorrect behaviour. + +* ``postselect_mode``: To configure how invalid shots are handled when postselecting mid-circuit measurements + with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid samples. In this case, the number + of samples that are used for processing results can be less than the total number of shots. If + ``postselect_mode="fill-shots"`` is used, then the postselected value will be sampled unconditionally, and all + samples will be valid. This is equivalent to sampling until the number of valid samples matches the total number + of shots. The default behaviour is ``postselect_mode="hw-like"``. + + .. code-block:: python3 + + import pennylane as qml + import numpy as np + + dev = qml.device("default.qubit", wires=3, shots=10) + + def circuit(x): + qml.RX(x, 0) + m0 = qml.measure(0, postselect=1) + qml.CNOT([0, 1]) + return qml.sample(qml.PauliZ(0)) + + fill_shots_qnode = qml.QNode(circuit, dev, mcm_method="one-shot", postselect_mode="fill-shots") + hw_like_qnode = qml.QNode(circuit, dev, mcm_method="one-shot", postselect_mode="hw-like") + + >>> fill_shots_qnode(np.pi / 2) + array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]) + >>> hw_like_qnode(np.pi / 2) + array([-1., -1., -1., -1., -1., -1., -1.]) + + .. note:: + + When using the ``jax`` interface, ``postselect_mode="hw-like"`` will have different behaviour based on the + chosen ``mcm_method``. + + * If ``mcm_method="one-shot"``, invalid shots will not be discarded. Instead, invalid samples will be replaced + by ``np.iinfo(np.int32).min``. These invalid samples will not be used for processing final results (like + expectation values), but will appear in the ``QNode`` output if samples are requested directly. Consider + the circuit below: + + .. code-block:: python3 + + import pennylane as qml + import jax + import jax.numpy as jnp + + dev = qml.device("default.qubit", wires=3, shots=10, seed=jax.random.PRNGKey(123)) + + @qml.qnode(dev, postselect_mode="hw-like", mcm_method="one-shot") + def circuit(x): + qml.RX(x, 0) + qml.measure(0, postselect=1) + return qml.sample(qml.PauliZ(0)) + + >>> x = jnp.array(1.8) + >>> f(x) + Array([-2.1474836e+09, -1.0000000e+00, -2.1474836e+09, -2.1474836e+09, + -1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -2.1474836e+09, + -1.0000000e+00, -1.0000000e+00], dtype=float32, weak_type=True) + + * When using ``jax.jit``, using ``mcm_method="deferred"`` is not supported with ``postselect_mode="hw-like"`` and + an error will be raised if this configuration is requested. This is due to limitations of the + :func:`~pennylane.defer_measurements` transform, and this behaviour will change in the future to be more + consistent with ``mcm_method="one-shot"``. + Changing the number of shots ---------------------------- diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8dd20c57e4b..5d0098d5ab7 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,6 +4,19 @@

New features since last release

+* `qml.QNode` and `qml.qnode` now accept two new keyword arguments: `postselect_mode` and `mcm_method`. + These keyword arguments can be used to configure how the device should behave when running circuits with + mid-circuit measurements. + [(#5679)](https://github.com/PennyLaneAI/pennylane/pull/5679) + + * `postselect_mode="hw-like"` will indicate to devices to discard invalid shots when postselecting + mid-circuit measurements. Use `postselect_mode="fill-shots"` to unconditionally sample the postselected + value, thus making all samples valid. This is equivalent to sampling until the number of valid samples + matches the total number of shots. + * `mcm_method` will indicate which strategy to use for running circuits with mid-circuit measurements. + Use `mcm_method="deferred"` to use the deferred measurements principle, or `mcm_method="one-shot"` + to execute once for each shot. + * The `default.tensor` device is introduced to perform tensor network simulation of a quantum circuit. [(#5699)](https://github.com/PennyLaneAI/pennylane/pull/5699) @@ -40,7 +53,7 @@ * The `dynamic_one_shot` transform can be compiled with `jax.jit`. [(#5557)](https://github.com/PennyLaneAI/pennylane/pull/5557) - + * When using `defer_measurements` with postselecting mid-circuit measurements, operations that will never be active due to the postselected state are skipped in the transformed quantum circuit. In addition, postselected controls are skipped, as they are evaluated diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py index aa8cad8ca87..963683a4107 100644 --- a/pennylane/_qubit_device.py +++ b/pennylane/_qubit_device.py @@ -468,7 +468,7 @@ def _multi_meas_with_counts_shot_vec(self, circuit: QuantumTape, shot_tuple, r): return new_r - def batch_execute(self, circuits): + def batch_execute(self, circuits, **kwargs): """Execute a batch of quantum circuits on the device. The circuits are represented by tapes, and they are executed one-by-one using the @@ -492,13 +492,16 @@ def batch_execute(self, circuits): ), ) + if self.capabilities().get("supports_mid_measure", False): + kwargs.setdefault("postselect_mode", None) + results = [] for circuit in circuits: # we need to reset the device here, else it will # not start the next computation in the zero state self.reset() - res = self.execute(circuit) + res = self.execute(circuit, **kwargs) results.append(res) if self.tracker.active: diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py index bbcd731e934..41525355214 100644 --- a/pennylane/capture/capture_qnode.py +++ b/pennylane/capture/capture_qnode.py @@ -15,6 +15,7 @@ This submodule defines a capture compatible call to QNodes. """ +from copy import copy from functools import lru_cache, partial import pennylane as qml @@ -159,7 +160,9 @@ def f(x): qfunc = partial(qnode.func, **kwargs) if kwargs else qnode.func qfunc_jaxpr = jax.make_jaxpr(qfunc)(*args) - qnode_kwargs = {"diff_method": qnode.diff_method, **qnode.execute_kwargs} + execute_kwargs = copy(qnode.execute_kwargs) + mcm_config = execute_kwargs.pop("mcm_config") + qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config} qnode_prim = _get_qnode_prim() return qnode_prim.bind( diff --git a/pennylane/devices/__init__.py b/pennylane/devices/__init__.py index 2c6ac9cf8ad..7757f321381 100644 --- a/pennylane/devices/__init__.py +++ b/pennylane/devices/__init__.py @@ -54,6 +54,7 @@ :toctree: api ExecutionConfig + MCMConfig Device DefaultQubit NullQubit @@ -146,7 +147,7 @@ def execute(self, circuits, execution_config = qml.devices.DefaultExecutionConfi """ -from .execution_config import ExecutionConfig, DefaultExecutionConfig +from .execution_config import ExecutionConfig, DefaultExecutionConfig, MCMConfig from .device_api import Device from .default_qubit import DefaultQubit diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 35dc391a6a5..c0a8a99f48d 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -501,7 +501,9 @@ def preprocess( transform_program = TransformProgram() transform_program.add_transform(validate_device_wires, self.wires, name=self.name) - transform_program.add_transform(mid_circuit_measurements, device=self) + transform_program.add_transform( + mid_circuit_measurements, device=self, mcm_config=config.mcm_config + ) transform_program.add_transform( decompose, stopping_condition=stopping_condition, @@ -597,6 +599,7 @@ def execute( "interface": interface, "state_cache": self._state_cache, "prng_key": _key, + "postselect_mode": execution_config.mcm_config.postselect_mode, }, ) for c, _key in zip(circuits, prng_keys) @@ -604,7 +607,14 @@ def execute( vanilla_circuits = convert_to_numpy_parameters(circuits)[0] seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits)) - simulate_kwargs = [{"rng": _rng, "prng_key": _key} for _rng, _key in zip(seeds, prng_keys)] + simulate_kwargs = [ + { + "rng": _rng, + "prng_key": _key, + "postselect_mode": execution_config.mcm_config.postselect_mode, + } + for _rng, _key in zip(seeds, prng_keys) + ] with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: exec_map = executor.map(_simulate_wrapper, vanilla_circuits, simulate_kwargs) @@ -848,6 +858,7 @@ def _simulate_wrapper(circuit, kwargs): def _adjoint_jac_wrapper(c, debugger=None): + c = c.map_to_standard_wires() state, is_state_batched = get_final_state(c, debugger=debugger) jac = adjoint_jacobian(c, state=state) res = measure_final_state(c, state, is_state_batched) @@ -855,6 +866,7 @@ def _adjoint_jac_wrapper(c, debugger=None): def _adjoint_jvp_wrapper(c, t, debugger=None): + c = c.map_to_standard_wires() state, is_state_batched = get_final_state(c, debugger=debugger) jvp = adjoint_jvp(c, t, state=state) res = measure_final_state(c, state, is_state_batched) @@ -862,6 +874,7 @@ def _adjoint_jvp_wrapper(c, t, debugger=None): def _adjoint_vjp_wrapper(c, t, debugger=None): + c = c.map_to_standard_wires() state, is_state_batched = get_final_state(c, debugger=debugger) vjp = adjoint_vjp(c, t, state=state) res = measure_final_state(c, state, is_state_batched) diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index e7ac20179a7..083f2880b6f 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -15,11 +15,39 @@ Contains the :class:`ExecutionConfig` data class. """ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union from pennylane.workflow import SUPPORTED_INTERFACES +@dataclass +class MCMConfig: + """A class to store mid-circuit measurement configurations.""" + + mcm_method: Optional[str] = None + """Which mid-circuit measurement strategy to use. Use ``deferred`` for the deferred + measurements principle and "one-shot" if using finite shots to execute the circuit + for each shot separately. If not specified, the device will decide which method to + use.""" + + postselect_mode: Optional[str] = None + """Configuration for handling shots with mid-circuit measurement postselection. If + ``"hw-like"``, invalid shots will be discarded and only results for valid shots will + be returned. If ``"fill-shots"``, results corresponding to the original number of + shots will be returned. If not specified, the device will decide which mode to use.""" + + def __post_init__(self): + """ + Validate the configured mid-circuit measurement options. + + Note that this hook is automatically called after init via the dataclass integration. + """ + if self.mcm_method not in ("deferred", "one-shot", None): + raise ValueError(f"Invalid mid-circuit measurements method '{self.mcm_method}'.") + if self.postselect_mode not in ("hw-like", "fill-shots", None): + raise ValueError(f"Invalid postselection mode '{self.postselect_mode}'.") + + # pylint: disable=too-many-instance-attributes @dataclass class ExecutionConfig: @@ -67,6 +95,9 @@ class ExecutionConfig: derivative_order: int = 1 """The derivative order to compute while evaluating a gradient""" + mcm_config: Union[MCMConfig, dict] = MCMConfig() + """Configuration options for handling mid-circuit measurements""" + def __post_init__(self): """ Validate the configured execution options. @@ -89,5 +120,10 @@ def __post_init__(self): if self.gradient_keyword_arguments is None: self.gradient_keyword_arguments = {} + if isinstance(self.mcm_config, dict): + self.mcm_config = MCMConfig(**self.mcm_config) + elif not isinstance(self.mcm_config, MCMConfig): + raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'") + DefaultExecutionConfig = ExecutionConfig() diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 882ed8445a1..23fed7614e6 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -28,6 +28,8 @@ from pennylane.typing import Result, ResultBatch from pennylane.wires import WireError +from .execution_config import MCMConfig + PostprocessingFn = Callable[[ResultBatch], Union[Result, ResultBatch]] @@ -80,7 +82,7 @@ def _operator_decomposition_gen( @transform def no_sampling( tape: qml.tape.QuantumTape, name: str = "device" -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Raises an error if the tape has finite shots. Args: @@ -104,7 +106,7 @@ def no_sampling( @transform def validate_device_wires( tape: qml.tape.QuantumTape, wires: Optional[qml.wires.Wires] = None, name: str = "device" -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Validates that all wires present in the tape are in the set of provided wires. Adds the device wires to measurement processes like :class:`~.measurements.StateMP` that are broadcasted across all available wires. @@ -145,15 +147,21 @@ def validate_device_wires( @transform def mid_circuit_measurements( - tape: qml.tape.QuantumTape, device -) -> (Sequence[qml.tape.QuantumTape], Callable): + tape: qml.tape.QuantumTape, device, mcm_config=MCMConfig() +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Provide the transform to handle mid-circuit measurements. If the tape or device uses finite-shot, use the native implementation (i.e. no transform), and use the ``qml.defer_measurements`` transform otherwise. """ - if tape.shots: + if isinstance(mcm_config, dict): + mcm_config = MCMConfig(**mcm_config) + mcm_method = mcm_config.mcm_method + if mcm_method is None: + mcm_method = "one-shot" if tape.shots else "deferred" + + if mcm_method == "one-shot": return qml.dynamic_one_shot(tape) return qml.defer_measurements(tape, device=device) @@ -161,7 +169,7 @@ def mid_circuit_measurements( @transform def validate_multiprocessing_workers( tape: qml.tape.QuantumTape, max_workers: int, device -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Validates the number of workers for multiprocessing. Checks that the CPU is not oversubscribed and warns user if it is, @@ -220,7 +228,7 @@ def validate_multiprocessing_workers( @transform def validate_adjoint_trainable_params( tape: qml.tape.QuantumTape, -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Raises a warning if any of the observables is trainable, and raises an error if any trainable parameters belong to state-prep operations. Can be used in validating circuits for adjoint differentiation. @@ -256,7 +264,7 @@ def decompose( max_expansion: Union[int, None] = None, name: str = "device", error: Exception = DeviceError, -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Decompose operations until the stopping condition is met. Args: @@ -370,7 +378,7 @@ def validate_observables( tape: qml.tape.QuantumTape, stopping_condition: Callable[[qml.operation.Operator], bool], name: str = "device", -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Validates the observables and measurements for a circuit. Args: @@ -412,7 +420,7 @@ def validate_observables( @transform def validate_measurements( tape: qml.tape.QuantumTape, analytic_measurements=None, sample_measurements=None, name="device" -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Validates the supported state and sample based measurement processes. Args: diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index ae0c9117c38..e5a3ed5464f 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -165,7 +165,18 @@ def apply_operation( is_state_batched (bool): Boolean representing whether the state is batched or not debugger (_Debugger): The debugger to use **execution_kwargs (Optional[dict]): Optional keyword arguments needed for applying - some operations + some operations described below. + + Keyword Arguments: + mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the sampled value + interface (str): The machine learning interface of the state + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. ``None`` by default. + rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. + prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is + the key to the JAX pseudo random number generator. Only for simulation using JAX. + If None, a ``numpy.random.default_rng`` will be used for sampling. Returns: ndarray: output state @@ -233,6 +244,11 @@ def apply_conditional( is_state_batched (bool): Boolean representing whether the state is batched or not debugger (_Debugger): The debugger to use mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the sampled value + interface (str): The machine learning interface of the state + rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. + prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is + the key to the JAX pseudo random number generator. Only for simulation using JAX. + If None, a ``numpy.random.default_rng`` will be used for sampling. Returns: ndarray: output state @@ -284,10 +300,13 @@ def apply_mid_measure( is_state_batched (bool): Boolean representing whether the state is batched or not debugger (_Debugger): The debugger to use mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the sampled value + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. ``None`` by default. rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. - If None, a ``numpy.random.default_rng`` will be for sampling. + If None, a ``numpy.random.default_rng`` will be used for sampling. Returns: ndarray: output state @@ -295,24 +314,31 @@ def apply_mid_measure( mid_measurements = execution_kwargs.get("mid_measurements", None) rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) + postselect_mode = execution_kwargs.get("postselect_mode", None) + if is_state_batched: raise ValueError("MidMeasureMP cannot be applied to batched states.") wire = op.wires - axis = wire.toarray()[0] - slices = [slice(None)] * qml.math.ndim(state) - slices[axis] = 0 - prob0 = qml.math.norm(state[tuple(slices)]) ** 2 interface = qml.math.get_deep_interface(state) - if prng_key is not None: - # pylint: disable=import-outside-toplevel - from jax.random import binomial - - def binomial_fn(n, p): - return binomial(prng_key, n, p).astype(int) + if postselect_mode == "fill-shots" and op.postselect is not None: + sample = op.postselect else: - binomial_fn = np.random.binomial if rng is None else rng.binomial - sample = binomial_fn(1, 1 - prob0) + axis = wire.toarray()[0] + slices = [slice(None)] * qml.math.ndim(state) + slices[axis] = 0 + prob0 = qml.math.norm(state[tuple(slices)]) ** 2 + + if prng_key is not None: + # pylint: disable=import-outside-toplevel + from jax.random import binomial + + def binomial_fn(n, p): + return binomial(prng_key, n, p).astype(int) + + else: + binomial_fn = np.random.binomial if rng is None else rng.binomial + sample = binomial_fn(1, 1 - prob0) mid_measurements[op] = sample # Using apply_operation(qml.QubitUnitary,...) instead of apply_operation(qml.Projector([sample], wire),...) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 47e8c85f36d..b862d590e19 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -63,13 +63,15 @@ def __init__(self, shots=None): if isinstance(shots, int): self.total_shots = shots self.shot_vector = (qml.measurements.ShotCopies(shots, 1),) + elif isinstance(shots, self.__class__): + return # self already _is_ shots as defined by __new__ else: self.__all_tuple_init__([s if isinstance(s, tuple) else (s, 1) for s in shots]) self._frozen = True -def _postselection_postprocess(state, is_state_batched, shots, rng=None, prng_key=None): +def _postselection_postprocess(state, is_state_batched, shots, **execution_kwargs): """Update state after projector is applied.""" if is_state_batched: raise ValueError( @@ -79,6 +81,10 @@ def _postselection_postprocess(state, is_state_batched, shots, rng=None, prng_ke "postselection is used." ) + rng = execution_kwargs.get("rng", None) + prng_key = execution_kwargs.get("prng_key", None) + postselect_mode = execution_kwargs.get("postselect_mode", None) + # The floor function is being used here so that a norm very close to zero becomes exactly # equal to zero so that the state can become invalid. This way, execution can continue, and # bad postselection gives results that are invalid rather than results that look valid but @@ -100,9 +106,9 @@ def _postselection_postprocess(state, is_state_batched, shots, rng=None, prng_ke binomial_fn = np.random.binomial if rng is None else rng.binomial postselected_shots = ( - [int(binomial_fn(s, float(norm**2))) for s in shots] - if not qml.math.is_abstract(norm) - else shots + shots + if postselect_mode == "fill-shots" or qml.math.is_abstract(norm) + else [int(binomial_fn(s, float(norm**2))) for s in shots] ) # _FlexShots is used here since the binomial distribution could result in zero @@ -121,25 +127,26 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): This is an internal function that will be called by the successor to ``default.qubit``. Args: - circuit (.QuantumScript): The single circuit to simulate + circuit (.QuantumScript): The single circuit to simulate. This circuit is assumed to have + non-negative integer wire labels debugger (._Debugger): The debugger to use interface (str): The machine learning interface to create the initial state with mid_measurements (None, dict): Dictionary of mid-circuit measurements rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. - If None, a ``numpy.random.default_rng`` will be for sampling. + If None, a ``numpy.random.default_rng`` will be used for sampling. + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. Default is ``None``. Returns: Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and whether the state has a batch dimension. """ - rng = execution_kwargs.get("rng", None) - prng_key = execution_kwargs.get("prng_key", None) + prng_key = execution_kwargs.pop("prng_key", None) interface = execution_kwargs.get("interface", None) - mid_measurements = execution_kwargs.get("mid_measurements", None) - circuit = circuit.map_to_standard_wires() prep = None if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): @@ -159,16 +166,16 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): state, is_state_batched=is_state_batched, debugger=debugger, - mid_measurements=mid_measurements, - rng=rng, prng_key=key, + **execution_kwargs, ) # Handle postselection on mid-circuit measurements if isinstance(op, qml.Projector): prng_key, key = jax_random_split(prng_key) - state, circuit._shots = _postselection_postprocess( - state, is_state_batched, circuit.shots, rng=rng, prng_key=key + state, new_shots = _postselection_postprocess( + state, is_state_batched, circuit.shots, prng_key=key, **execution_kwargs ) + circuit._shots = new_shots # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim is_state_batched = is_state_batched or (op.batch_size is not None) @@ -190,7 +197,8 @@ def measure_final_state(circuit, state, is_state_batched, **execution_kwargs) -> This is an internal function that will be called by the successor to ``default.qubit``. Args: - circuit (.QuantumScript): The single circuit to simulate + circuit (.QuantumScript): The single circuit to simulate. This circuit is assumed to have + non-negative integer wire labels state (TensorLike): The state to perform measurement on is_state_batched (bool): Whether the state has a batch dimension or not. rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A @@ -199,7 +207,7 @@ def measure_final_state(circuit, state, is_state_batched, **execution_kwargs) -> prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. If None, the default ``sample_state`` function and a ``numpy.random.default_rng`` - will be for sampling. + will be used for sampling. mid_measurements (None, dict): Dictionary of mid-circuit measurements Returns: @@ -209,8 +217,6 @@ def measure_final_state(circuit, state, is_state_batched, **execution_kwargs) -> prng_key = execution_kwargs.get("prng_key", None) mid_measurements = execution_kwargs.get("mid_measurements", None) - circuit = circuit.map_to_standard_wires() - # analytic case if not circuit.shots: @@ -268,6 +274,9 @@ def simulate( the key to the JAX pseudo random number generator. If None, a random key will be generated. Only for simulation using JAX. interface (str): The machine learning interface to create the initial state with + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. Default is ``None``. Returns: tuple(TensorLike): The results of the simulation @@ -282,9 +291,8 @@ def simulate( tensor([0.68117888, 0. , 0.31882112, 0. ], requires_grad=True)) """ - rng = execution_kwargs.get("rng", None) - prng_key = execution_kwargs.get("prng_key", None) - interface = execution_kwargs.get("interface", None) + prng_key = execution_kwargs.pop("prng_key", None) + circuit = circuit.map_to_standard_wires() has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations) if circuit.shots and has_mcm: @@ -302,7 +310,7 @@ def simulate( def simulate_partial(k): return simulate_one_shot_native_mcm( - aux_circ, debugger=debugger, rng=rng, prng_key=k, interface=interface + aux_circ, debugger=debugger, prng_key=k, **execution_kwargs ) results = jax.vmap(simulate_partial, in_axes=(0,))(keys) @@ -311,18 +319,20 @@ def simulate_partial(k): for i in range(circuit.shots.total_shots): results.append( simulate_one_shot_native_mcm( - aux_circ, debugger=debugger, rng=rng, prng_key=keys[i], interface=interface + aux_circ, debugger=debugger, prng_key=keys[i], **execution_kwargs ) ) return tuple(results) ops_key, meas_key = jax_random_split(prng_key) state, is_state_batched = get_final_state( - circuit, debugger=debugger, rng=rng, prng_key=ops_key, interface=interface + circuit, debugger=debugger, prng_key=ops_key, **execution_kwargs ) if state_cache is not None: state_cache[circuit.hash] = state - return measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=meas_key) + return measure_final_state( + circuit, state, is_state_batched, prng_key=meas_key, **execution_kwargs + ) @debug_logger @@ -339,30 +349,30 @@ def simulate_one_shot_native_mcm( the key to the JAX pseudo random number generator. If None, a random key will be generated. Only for simulation using JAX. interface (str): The machine learning interface to create the initial state with + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. Default is ``None``. Returns: tuple(TensorLike): The results of the simulation dict: The mid-circuit measurement results of the simulation """ - rng = execution_kwargs.get("rng", None) - prng_key = execution_kwargs.get("prng_key", None) - interface = execution_kwargs.get("interface", None) + prng_key = execution_kwargs.pop("prng_key", None) ops_key, meas_key = jax_random_split(prng_key) mid_measurements = {} state, is_state_batched = get_final_state( circuit, debugger=debugger, - interface=interface, mid_measurements=mid_measurements, - rng=rng, prng_key=ops_key, + **execution_kwargs, ) return measure_final_state( circuit, state, is_state_batched, - rng=rng, prng_key=meas_key, mid_measurements=mid_measurements, + **execution_kwargs, ) diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index d1c6f05a5f4..69de5f88a2c 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -104,7 +104,7 @@ def null_postprocessing(results): @transform def defer_measurements( tape: QuantumTape, reduce_postselected: bool = True, **kwargs -) -> (Sequence[QuantumTape], Callable): +) -> tuple[Sequence[QuantumTape], Callable]: """Quantum function transform that substitutes operations conditioned on measurement outcomes to controlled operations. diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 998cefa93c5..831e30031af 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -33,6 +33,7 @@ SampleMP, VarianceMP, ) +from pennylane.typing import TensorLike from .core import transform @@ -49,7 +50,7 @@ def null_postprocessing(results): @transform def dynamic_one_shot( tape: qml.tape.QuantumTape, **kwargs -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Transform a QNode to into several one-shot tapes to support dynamic circuit execution. Args: @@ -207,14 +208,14 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): def parse_native_mid_circuit_measurements( - circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results + circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results: TensorLike ): """Combines, gathers and normalizes the results of native mid-circuit measurement runs. Args: - circuit (QuantumTape): A one-shot (auxiliary) QuantumScript - all_shot_meas (Sequence[Any]): List of accumulated measurement results - mcm_shot_meas (Sequence[dict]): List of dictionaries containing the mid-circuit measurement results of each shot + circuit (QuantumTape): Initial ``QuantumScript`` + aux_tapes (List[QuantumTape]): List of auxilary ``QuantumScript`` objects + results (TensorLike): Array of measurement results Returns: tuple(TensorLike): The results of the simulation @@ -287,43 +288,42 @@ def measurement_with_no_shots(measurement): return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] -def gather_non_mcm(circuit_measurement, measurement, is_valid): +def gather_non_mcm(measurement, samples, is_valid): """Combines, gathers and normalizes several measurements with trivial measurement values. Args: - circuit_measurement (MeasurementProcess): measurement - measurement (TensorLike): measurement results - samples (List[dict]): Mid-circuit measurement samples + measurement (MeasurementProcess): measurement + samples (TensorLike): Post-processed measurement samples + is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at + each index specifies whether or not the respective sample is valid. Returns: TensorLike: The combined measurement outcome """ - if isinstance(circuit_measurement, CountsMP): + if isinstance(measurement, CountsMP): tmp = Counter() - for i, d in enumerate(measurement): + for i, d in enumerate(samples): tmp.update( dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items()) ) tmp = Counter({k: v for k, v in tmp.items() if v > 0}) return dict(sorted(tmp.items())) - if isinstance(circuit_measurement, ExpectationMP): - return qml.math.sum(measurement * is_valid) / qml.math.sum(is_valid) - if isinstance(circuit_measurement, ProbabilityMP): - return qml.math.sum(measurement * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum( - is_valid - ) - if isinstance(circuit_measurement, SampleMP): + if isinstance(measurement, ExpectationMP): + return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid) + if isinstance(measurement, ProbabilityMP): + return qml.math.sum(samples * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum(is_valid) + if isinstance(measurement, SampleMP): is_interface_jax = qml.math.get_deep_interface(is_valid) == "jax" - if is_interface_jax and measurement.ndim == 2: + if is_interface_jax and samples.ndim == 2: is_valid = is_valid.reshape((-1, 1)) return ( - qml.math.where(is_valid, measurement, fill_in_value) + qml.math.where(is_valid, samples, fill_in_value) if is_interface_jax - else measurement[is_valid] + else samples[is_valid] ) # VarianceMP - expval = qml.math.sum(measurement * is_valid) / qml.math.sum(is_valid) - return qml.math.sum((measurement - expval) ** 2 * is_valid) / qml.math.sum(is_valid) + expval = qml.math.sum(samples * is_valid) / qml.math.sum(is_valid) + return qml.math.sum((samples - expval) ** 2 * is_valid) / qml.math.sum(is_valid) def gather_mcm(measurement, samples, is_valid): @@ -332,6 +332,8 @@ def gather_mcm(measurement, samples, is_valid): Args: measurement (MeasurementProcess): measurement samples (List[dict]): Mid-circuit measurement samples + is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at + each index specifies whether or not the respective sample is valid. Returns: TensorLike: The combined measurement outcome diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 4348e5f783b..f27d100e752 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -267,7 +267,18 @@ def _make_inner_execute( """ if isinstance(device, qml.devices.LegacyDevice): - device_execution = set_shots(device, override_shots)(device.batch_execute) + dev_execute = ( + device.batch_execute + # If this condition is not met, then dev.batch_execute likely also doesn't include + # any kwargs in its signature, hence why we use partial conditionally + if execution_config is None + or not device.capabilities().get("supports_mid_measure", False) + else partial( + device.batch_execute, + postselect_mode=execution_config.mcm_config.postselect_mode, + ) + ) + device_execution = set_shots(device, override_shots)(dev_execute) else: device_execution = partial(device.execute, execution_config=execution_config) @@ -360,6 +371,36 @@ def execution_function_with_caching(tapes): return execution_function_with_caching +def _get_interface_name(tapes, interface): + """Helper function to get the interface name of a list of tapes + + Args: + tapes (list[.QuantumScript]): Quantum tapes + interface (Optional[str]): Original interface to use as reference. + + Returns: + str: Interface name""" + if interface == "auto": + params = [] + for tape in tapes: + params.extend(tape.get_parameters(trainable_only=False)) + interface = qml.math.get_interface(*params) + if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph(): + interface = "tf-autograph" + if interface == "jax": + try: # pragma: no cover + from .interfaces.jax import get_jax_interface_name + except ImportError as e: # pragma: no cover + raise qml.QuantumFunctionError( # pragma: no cover + "jax not found. Please install the latest " # pragma: no cover + "version of jax to enable the 'jax' interface." # pragma: no cover + ) from e # pragma: no cover + + interface = get_jax_interface_name(tapes) + + return interface + + def execute( tapes: Sequence[QuantumTape], device: device_type, @@ -377,6 +418,7 @@ def execute( max_expansion=10, device_batch_transform=True, device_vjp=False, + mcm_config=None, ) -> ResultBatch: """New function to execute a batch of tapes on a device in an autodifferentiable-compatible manner. More cases will be added, during the project. The current version is supporting forward execution for NumPy and does not support shot vectors. @@ -423,6 +465,7 @@ def execute( constituent terms if not supported on the device. device_vjp=False (Optional[bool]): whether or not to use the device provided jacobian product if it is available. + mcm_config (dict): Dictionary containing configuration options for handling mid-circuit measurements. Returns: list[tensor_like[float]]: A nested list of tape results. Each element in @@ -512,26 +555,10 @@ def cost_fn(params, x): ### Specifying and preprocessing variables #### - if interface == "auto": - params = [] - for tape in tapes: - params.extend(tape.get_parameters(trainable_only=False)) - interface = qml.math.get_interface(*params) - if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph(): - interface = "tf-autograph" - if interface == "jax": - try: # pragma: no-cover - from .interfaces.jax import get_jax_interface_name - except ImportError as e: # pragma: no-cover - raise qml.QuantumFunctionError( # pragma: no-cover - "jax not found. Please install the latest " # pragma: no-cover - "version of jax to enable the 'jax' interface." # pragma: no-cover - ) from e # pragma: no-cover - - interface = get_jax_interface_name(tapes) - # Only need to calculate derivatives with jax when we know it will be executed later. - if interface in {"jax", "jax-jit"}: - grad_on_execution = grad_on_execution if isinstance(gradient_fn, Callable) else False + interface = _get_interface_name(tapes, interface) + # Only need to calculate derivatives with jax when we know it will be executed later. + if interface in {"jax", "jax-jit"}: + grad_on_execution = grad_on_execution if isinstance(gradient_fn, Callable) else False if ( device_vjp @@ -543,10 +570,20 @@ def cost_fn(params, x): ) gradient_kwargs = gradient_kwargs or {} + mcm_config = mcm_config or {} config = config or _get_execution_config( - gradient_fn, grad_on_execution, interface, device, device_vjp + gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) + # Mid-circuit measurement configuration validation + mcm_interface = _get_interface_name(tapes, "auto") if interface is None else interface + if mcm_interface == "jax-jit" and config.mcm_config.mcm_method == "deferred": + # This is a current limitation of defer_measurements. "hw-like" behaviour is + # not yet accessible. + if config.mcm_config.postselect_mode == "hw-like": + raise ValueError("Using postselect_mode='hw-like' is not supported with jax-jit.") + config.mcm_config.postselect_mode = "fill-shots" + if transform_program is None: if isinstance(device, qml.devices.Device): transform_program = device.preprocess(config)[0] @@ -790,7 +827,9 @@ def device_gradient_fn(inner_tapes, **gradient_kwargs): return post_processing(results) -def _get_execution_config(gradient_fn, grad_on_execution, interface, device, device_vjp): +def _get_execution_config( + gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config +): """Helper function to get the execution config.""" if gradient_fn is None: _gradient_method = None @@ -803,6 +842,7 @@ def _get_execution_config(gradient_fn, grad_on_execution, interface, device, dev gradient_method=_gradient_method, grad_on_execution=None if grad_on_execution == "best" else grad_on_execution, use_device_jacobian_product=device_vjp, + mcm_config=mcm_config, ) if isinstance(device, qml.devices.Device): _, config = device.preprocess(config) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 525f2d66263..e0fd75458c4 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -69,18 +69,20 @@ def _make_execution_config( _gradient_method = diff_method else: _gradient_method = "gradient-transform" - grad_on_execution = getattr(circuit, "execute_kwargs", {}).get("grad_on_execution") + execute_kwargs = getattr(circuit, "execute_kwargs", {}) + grad_on_execution = execute_kwargs.get("grad_on_execution") if getattr(circuit, "interface", "") == "jax": grad_on_execution = False elif grad_on_execution == "best": grad_on_execution = None + mcm_config = execute_kwargs.get("mcm_config", {}) + return qml.devices.ExecutionConfig( interface=getattr(circuit, "interface", None), gradient_method=_gradient_method, grad_on_execution=grad_on_execution, - use_device_jacobian_product=getattr(circuit, "execute_kwargs", {"device_vjp": False})[ - "device_vjp" - ], + use_device_jacobian_product=execute_kwargs.get("device_vjp", False), + mcm_config=mcm_config, ) @@ -217,6 +219,16 @@ class QNode: (classical) computational overhead during the backwards pass. device_vjp (bool): Whether or not to use the device-provided Vector Jacobian Product (VJP). A value of ``None`` indicates to use it if the device provides it, but use the full jacobian otherwise. + postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. If + ``"hw-like"``, invalid shots will be discarded and only results for valid shots will be returned. + If ``"fill-shots"``, results corresponding to the original number of shots will be returned. The + default is ``None``, in which case the device will automatically choose the best configuration. For + usage details, please refer to the :doc:`main measurements page `. + mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements. Use ``"deferred"`` + to apply the deferred measurements principle (using the :func:`~pennylane.defer_measurements` transform), + or ``"one-shot"`` if using finite shots to execute the circuit for each shot separately. If not provided, + the device will determine the best choice automatically. For usage details, please refer to the + :doc:`main measurements page `. Keyword Args: **kwargs: Any additional keyword arguments provided are passed to the differentiation @@ -440,6 +452,8 @@ def __init__( cachesize=10000, max_diff=1, device_vjp=False, + postselect_mode=None, + mcm_method=None, **gradient_kwargs, ): if logger.isEnabledFor(logging.DEBUG): @@ -512,6 +526,11 @@ def __init__( self.max_expansion = max_expansion cache = (max_diff > 1) if cache == "auto" else cache + if mcm_method not in ("deferred", "one-shot", None): + raise ValueError(f"Invalid mid-circuit measurements method '{mcm_method}'.") + if postselect_mode not in ("hw-like", "fill-shots", None): + raise ValueError(f"Invalid postselection mode '{postselect_mode}'.") + # execution keyword arguments self.execute_kwargs = { "grad_on_execution": grad_on_execution, @@ -520,6 +539,7 @@ def __init__( "max_diff": max_diff, "max_expansion": max_expansion, "device_vjp": device_vjp, + "mcm_config": {"postselect_mode": postselect_mode, "mcm_method": mcm_method}, } if self.expansion_strategy == "device": @@ -985,20 +1005,6 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches # check here only if enough wires raise qml.QuantumFunctionError(f"Operator {obj.name} must act on all wires") - # Apply the deferred measurement principle if the device doesn't - # support mid-circuit measurements natively. - # Only apply transform with old device API as postselection with - # broadcasting will split tapes. - expand_mid_measure = ( - any(isinstance(op, MidMeasureMP) for op in self.tape.operations) - and not isinstance(self.device, qml.devices.Device) - and not self.device.capabilities().get("supports_mid_measure", False) - ) - if expand_mid_measure or self.expansion_strategy == "device": - # Assume that tapes are not split if old device is used since postselection is not supported. - tapes, _ = qml.defer_measurements(self._tape, device=self.device) - self._tape = tapes[0] - if self.expansion_strategy == "device": if isinstance(self.device, qml.devices.Device): tape, _ = self.device.preprocess()[0]([self.tape]) @@ -1034,6 +1040,12 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml ) self._tape_cached = using_custom_cache and self.tape.hash in cache + finite_shots = _get_device_shots if override_shots is False else override_shots + if not finite_shots and self.execute_kwargs["mcm_config"]["mcm_method"] == "one-shot": + raise ValueError( + "Cannot use the 'one-shot' method for mid-circuit measurements with analytic mode." + ) + # Add the device program to the QNode program if isinstance(self.device, qml.devices.Device): config = _make_execution_config(self, self.gradient_fn) @@ -1042,14 +1054,24 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml else: config = None full_transform_program = qml.transforms.core.TransformProgram(self.transform_program) + has_mcm_support = ( any(isinstance(op, MidMeasureMP) for op in self._tape) and hasattr(self.device, "capabilities") and self.device.capabilities().get("supports_mid_measure", False) ) if has_mcm_support: - full_transform_program.add_transform(qml.dynamic_one_shot) + full_transform_program.add_transform( + qml.devices.preprocess.mid_circuit_measurements, + device=self.device, + mcm_config=self.execute_kwargs["mcm_config"], + ) override_shots = 1 + elif hasattr(self.device, "capabilities"): + full_transform_program.add_transform( + qml.defer_measurements, + device=self.device, + ) # Add the gradient expand to the program if necessary if getattr(self.gradient_fn, "expand_transform", False): diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 1fe57c65e23..ec5387647d2 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -121,6 +121,7 @@ def circuit(x): assert eqn0.params["shots"] == qml.measurements.Shots(None) expected_kwargs = {"diff_method": "best"} expected_kwargs.update(circuit.execute_kwargs) + expected_kwargs.update(expected_kwargs.pop("mcm_config")) assert eqn0.params["qnode_kwargs"] == expected_kwargs qfunc_jaxpr = eqn0.params["qfunc_jaxpr"] @@ -294,5 +295,7 @@ def circuit(): "max_diff": 2, "max_expansion": 10, "device_vjp": False, + "mcm_method": None, + "postselect_mode": None, } assert jaxpr.eqns[0].params["qnode_kwargs"] == expected diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index 4af0d167ebb..4ee14716f77 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -18,7 +18,6 @@ import numpy as np import pytest -from flaky import flaky import pennylane as qml from pennylane.devices import DefaultQubit, ExecutionConfig @@ -1826,7 +1825,6 @@ def circ_expected(): assert qml.math.allclose(res, expected) assert qml.math.get_interface(res) == qml.math.get_interface(expected) - @flaky(max_runs=5) @pytest.mark.parametrize( "mp", [ @@ -1848,9 +1846,7 @@ def test_postselection_valid_finite_shots(self, param, mp, shots, interface, use if use_jit and (interface != "jax" or isinstance(shots, tuple)): pytest.skip("Cannot JIT in non-JAX interfaces, or with shot vectors.") - np.random.seed(42) - - dev = qml.device("default.qubit") + dev = qml.device("default.qubit", seed=1971) param = qml.math.asarray(param, like=interface) @qml.defer_measurements @@ -1888,7 +1884,12 @@ def circ_expected(): @pytest.mark.parametrize( "mp, expected_shape", - [(qml.sample(wires=[0]), (5,)), (qml.classical_shadow(wires=[0]), (2, 5, 1))], + [ + (qml.sample(wires=[0, 2]), (5, 2)), + (qml.classical_shadow(wires=[0, 2]), (2, 5, 2)), + (qml.sample(wires=[0]), (5,)), + (qml.classical_shadow(wires=[0]), (2, 5, 1)), + ], ) @pytest.mark.parametrize("param", np.linspace(np.pi / 4, 3 * np.pi / 4, 3)) @pytest.mark.parametrize("shots", [10, (10, 10)]) @@ -1918,11 +1919,6 @@ def circ_postselect(theta): qml.measure(0, postselect=1) return qml.apply(mp) - if use_jit: - import jax - - circ_postselect = jax.jit(circ_postselect, static_argnames=["shots"]) - res = circ_postselect(param, shots=shots) if not isinstance(shots, tuple): diff --git a/tests/devices/experimental/test_execution_config.py b/tests/devices/experimental/test_execution_config.py index ce751d9e2f2..361712c112e 100644 --- a/tests/devices/experimental/test_execution_config.py +++ b/tests/devices/experimental/test_execution_config.py @@ -17,7 +17,7 @@ import pytest -from pennylane.devices import ExecutionConfig +from pennylane.devices.execution_config import ExecutionConfig, MCMConfig def test_default_values(): @@ -30,6 +30,14 @@ def test_default_values(): assert config.gradient_keyword_arguments == {} assert config.grad_on_execution is None assert config.use_device_gradient is None + assert config.mcm_config == MCMConfig() + + +def test_mcm_config_default_values(): + """Test that the default values of MCMConfig are correct""" + mcm_config = MCMConfig() + assert mcm_config.postselect_mode is None + assert mcm_config.mcm_method is None def test_invalid_interface(): @@ -49,3 +57,36 @@ def test_invalid_grad_on_execution(): """Test invalid values for grad on execution raise an error.""" with pytest.raises(ValueError, match=r"grad_on_execution must be True, False,"): ExecutionConfig(grad_on_execution="forward") + + +@pytest.mark.parametrize( + "option", [MCMConfig(mcm_method="deferred"), {"mcm_method": "deferred"}, None] +) +def test_valid_execution_config_mcm_config(option): + """Test that the mcm_config attribute is set correctly""" + config = ExecutionConfig(mcm_config=option) if option else ExecutionConfig() + if option is None: + assert config.mcm_config == MCMConfig() + else: + assert config.mcm_config == MCMConfig(mcm_method="deferred") + + +def test_invalid_execution_config_mcm_config(): + """Test that an error is raised if mcm_config is set incorrectly""" + option = "foo" + with pytest.raises(ValueError, match="Got invalid type"): + _ = ExecutionConfig(mcm_config=option) + + +def test_mcm_config_invalid_mcm_method(): + """Test that an error is raised if creating MCMConfig with invalid mcm_method""" + option = "foo" + with pytest.raises(ValueError, match="Invalid mid-circuit measurements method"): + _ = MCMConfig(mcm_method=option) + + +def test_mcm_config_invalid_postselect_mode(): + """Test that an error is raised if creating MCMConfig with invalid postselect_mode""" + option = "foo" + with pytest.raises(ValueError, match="Invalid postselection mode"): + _ = MCMConfig(postselect_mode=option) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index ac54393a028..634c8fed2e5 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -18,6 +18,7 @@ import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate +from pennylane.devices.qubit.simulate import _FlexShots class TestCurrentlyUnsupportedCases: @@ -409,6 +410,30 @@ def test_nan_state(self, interface): assert qml.math.all(qml.math.isnan(res)) +class Test_FlexShots: + """Unit tests for _FlexShots""" + + @pytest.mark.parametrize( + "shots, expected_shot_vector", + [ + (0, (0,)), + ((10, 0, 5, 0), (10, 0, 5, 0)), + (((10, 3), (0, 5)), (10, 10, 10, 0, 0, 0, 0, 0)), + ], + ) + def test_init_with_zero_shots(self, shots, expected_shot_vector): + """Test that _FlexShots is initialized correctly with zero shots""" + flex_shots = _FlexShots(shots) + shot_vector = tuple(s for s in flex_shots) + assert shot_vector == expected_shot_vector + + def test_init_with_other_shots(self): + """Test that a new _FlexShots object is not created if the input is a _FlexShots object.""" + shots = _FlexShots(10) + new_shots = _FlexShots(shots) + assert new_shots is shots + + class TestDebugger: """Tests that the debugger works for a simple circuit""" diff --git a/tests/devices/test_default_qubit_legacy.py b/tests/devices/test_default_qubit_legacy.py index a5b4df81cec..d3059b8ab38 100644 --- a/tests/devices/test_default_qubit_legacy.py +++ b/tests/devices/test_default_qubit_legacy.py @@ -17,7 +17,6 @@ # pylint: disable=too-many-arguments,too-few-public-methods # pylint: disable=protected-access,cell-var-from-loop import cmath -import copy import math from functools import partial @@ -89,35 +88,6 @@ VARPHI = np.linspace(0.02, 1, 3) -def test_qnode_native_mcm(mocker): - """Tests that the legacy devices may support native MCM execution via the dynamic_one_shot transform.""" - - class MCMDevice(DefaultQubitLegacy): - def apply(self, *args, **kwargs): - for op in args[0]: - if isinstance(op, qml.measurements.MidMeasureMP): - kwargs["mid_measurements"][op] = 0 - - @classmethod - def capabilities(cls): - default_capabilities = copy.copy(DefaultQubitLegacy.capabilities()) - default_capabilities["supports_mid_measure"] = True - return default_capabilities - - dev = MCMDevice(wires=1, shots=100) - dev.operations.add("MidMeasureMP") - spy = mocker.spy(qml.dynamic_one_shot, "_transform") - - @qml.qnode(dev, interface=None, diff_method=None) - def func(): - _ = qml.measure(0) - return qml.expval(op=qml.PauliZ(0)) - - res = func() - assert spy.call_count == 1 - assert isinstance(res, float) - - def test_analytic_deprecation(): """Tests if the kwarg `analytic` is used and displays error message.""" msg = "The analytic argument has been replaced by shots=None. " diff --git a/tests/devices/test_preprocess.py b/tests/devices/test_preprocess.py index 38c3d48371e..f6f6aab592a 100644 --- a/tests/devices/test_preprocess.py +++ b/tests/devices/test_preprocess.py @@ -21,6 +21,7 @@ from pennylane.devices.preprocess import ( _operator_decomposition_gen, decompose, + mid_circuit_measurements, no_sampling, validate_adjoint_trainable_params, validate_device_wires, @@ -440,6 +441,42 @@ def test_decompose_initial_state_prep_if_requested(self, prep_op): assert new_tape[0] != prep_op +class TestMidCircuitMeasurements: + """Unit tests for the mid_circuit_measurements preprocessing transform""" + + @pytest.mark.parametrize( + "mcm_method, shots, expected_transform", + [ + ("deferred", 10, qml.defer_measurements), + ("deferred", None, qml.defer_measurements), + (None, None, qml.defer_measurements), + (None, 10, qml.dynamic_one_shot), + ("one-shot", 10, qml.dynamic_one_shot), + ], + ) + def test_mcm_method(self, mcm_method, shots, expected_transform, mocker): + """Test that the preprocessing transform adheres to the specified transform""" + dev = qml.device("default.qubit") + mcm_config = {"postselect_mode": None, "mcm_method": mcm_method} + tape = QuantumScript([qml.measurements.MidMeasureMP(0)], [], shots=shots) + spy = mocker.spy(expected_transform, "_transform") + + _, _ = mid_circuit_measurements(tape, dev, mcm_config) + spy.assert_called_once() + + def test_error_incompatible_mcm_method(self): + """Test that an error is raised if requesting the one-shot transform without shots""" + dev = qml.device("default.qubit") + shots = None + mcm_config = {"postselect_mode": None, "mcm_method": "one-shot"} + tape = QuantumScript([qml.measurements.MidMeasureMP(0)], [], shots=shots) + + with pytest.raises( + qml.QuantumFunctionError, match="dynamic_one_shot is only supported with finite shots." + ): + _, _ = mid_circuit_measurements(tape, dev, mcm_config) + + def test_validate_multiprocessing_workers_None(): """Test that validation does not fail when max_workers is None""" qs = QuantumScript( diff --git a/tests/interfaces/test_jacobian_products.py b/tests/interfaces/test_jacobian_products.py index 5796aa0cda9..07f3ccb60ab 100644 --- a/tests/interfaces/test_jacobian_products.py +++ b/tests/interfaces/test_jacobian_products.py @@ -151,7 +151,8 @@ def test_device_jacobians_repr(self): r" ExecutionConfig(grad_on_execution=None, use_device_gradient=None," r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}," - r" device_options={}, interface=None, derivative_order=1)>" + r" device_options={}, interface=None, derivative_order=1," + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" ) assert repr(jpc) == expected @@ -169,7 +170,8 @@ def test_device_jacobian_products_repr(self): r" ExecutionConfig(grad_on_execution=None, use_device_gradient=None," r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={}," - r" interface=None, derivative_order=1)>" + r" interface=None, derivative_order=1," + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" ) assert repr(jpc) == expected diff --git a/tests/test_qnode.py b/tests/test_qnode.py index b8c965b5d51..cf840c47766 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -878,11 +878,8 @@ def circuit(x, y): assert np.allclose(res, expected, atol=tol, rtol=0) @pytest.mark.parametrize( - "dev, call_count", - [ - (qml.device("default.qubit", wires=3), 2), - (qml.device("default.qubit.legacy", wires=3), 1), - ], + "dev", + [qml.device("default.qubit", wires=3), qml.device("default.qubit.legacy", wires=3)], ) @pytest.mark.parametrize("first_par", np.linspace(0.15, np.pi - 0.3, 3)) @pytest.mark.parametrize("sec_par", np.linspace(0.15, np.pi - 0.3, 3)) @@ -898,7 +895,7 @@ def circuit(x, y): ], ) def test_defer_meas_if_mcm_unsupported( - self, dev, call_count, first_par, sec_par, return_type, mv_return, mv_res, mocker + self, dev, first_par, sec_par, return_type, mv_return, mv_res, mocker ): # pylint: disable=too-many-arguments """Tests that the transform using the deferred measurement principle is applied if the device doesn't support mid-circuit measurements @@ -928,7 +925,7 @@ def conditional_ry_qnode(x, y): assert np.allclose(r1, r2[0]) assert np.allclose(r2[1], mv_res(first_par)) - assert spy.call_count == call_count # once for each preprocessing + assert spy.call_count == 2 @pytest.mark.parametrize("dev_name", ["default.qubit.legacy", "default.mixed"]) def test_dynamic_one_shot_if_mcm_unsupported(self, dev_name): @@ -1717,6 +1714,116 @@ def circuit(): assert qml.math.allclose(results, np.zeros((20, 2))) +class TestMCMConfiguration: + """Tests for MCM configuration arguments""" + + @pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.legacy"]) + def test_one_shot_error_without_shots(self, dev_name): + """Test that an error is raised if mcm_method="one-shot" with no shots""" + dev = qml.device(dev_name, wires=3) + param = np.pi / 4 + + @qml.qnode(dev, mcm_method="one-shot") + def f(x): + qml.RX(x, 0) + _ = qml.measure(0) + return qml.probs(wires=[0, 1]) + + with pytest.raises( + ValueError, match="Cannot use the 'one-shot' method for mid-circuit measurements with" + ): + _ = f(param) + + def test_invalid_mcm_method_error(self): + """Test that an error is raised if the requested mcm_method is invalid""" + shots = 100 + dev = qml.device("default.qubit", wires=3, shots=shots) + + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.sample(wires=[0, 1]) + + with pytest.raises(ValueError, match="Invalid mid-circuit measurements method 'foo'"): + _ = qml.QNode(f, dev, mcm_method="foo") + + def test_invalid_postselect_mode_error(self): + """Test that an error is raised if the requested postselect_mode is invalid""" + shots = 100 + dev = qml.device("default.qubit", wires=3, shots=shots) + + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.sample(wires=[0, 1]) + + with pytest.raises(ValueError, match="Invalid postselection mode 'foo'"): + _ = qml.QNode(f, dev, postselect_mode="foo") + + @pytest.mark.jax + @pytest.mark.parametrize("diff_method", [None, "best"]) + def test_defer_measurements_with_jit(self, diff_method, mocker): + """Test that using mcm_method="deferred" defaults to behaviour like + postselect_mode="fill-shots" when using jax jit.""" + import jax # pylint: disable=import-outside-toplevel + + shots = 100 + postselect = 1 + param = jax.numpy.array(np.pi / 2) + spy = mocker.spy(qml.defer_measurements, "_transform") + spy_one_shot = mocker.spy(qml.dynamic_one_shot, "_transform") + + dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123)) + + @qml.qnode(dev, diff_method=diff_method, mcm_method="deferred") + def f(x): + qml.RX(x, 0) + qml.measure(0, postselect=postselect) + return qml.sample(wires=0) + + f_jit = jax.jit(f) + res = f(param) + res_jit = f_jit(param) + + assert spy.call_count > 0 + spy_one_shot.assert_not_called() + + assert len(res) < shots + assert len(res_jit) == shots + assert qml.math.allclose(res, postselect) + assert qml.math.allclose(res_jit, postselect) + + @pytest.mark.jax + # @pytest.mark.parametrize("diff_method", [None, "best"]) + @pytest.mark.parametrize("diff_method", ["best"]) + def test__deferred_hw_like_error_with_jit(self, diff_method): + """Test that an error is raised if attempting to use postselect_mode="hw-like" + with jax jit with mcm_method="deferred".""" + import jax # pylint: disable=import-outside-toplevel + + shots = 100 + postselect = 1 + param = jax.numpy.array(np.pi / 2) + + dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123)) + + @qml.qnode(dev, diff_method=diff_method, mcm_method="deferred", postselect_mode="hw-like") + def f(x): + qml.RX(x, 0) + qml.measure(0, postselect=postselect) + return qml.sample(wires=0) + + f_jit = jax.jit(f) + + # Checking that an error is not raised without jit + _ = f(param) + + with pytest.raises( + ValueError, match="Using postselect_mode='hw-like' is not supported with jax-jit." + ): + _ = f_jit(param) + + class TestTapeExpansion: """Test that tape expansion within the QNode works correctly""" diff --git a/tests/test_qnode_legacy.py b/tests/test_qnode_legacy.py index 3ac017a8f9d..34f7623c37b 100644 --- a/tests/test_qnode_legacy.py +++ b/tests/test_qnode_legacy.py @@ -1075,7 +1075,7 @@ def test_no_defer_measurements_if_supported(self, mocker): QNode construction if the device supports mid-circuit measurements.""" dev = qml.device("default.qubit.legacy", wires=3) mocker.patch.object(qml.Device, "_capabilities", {"supports_mid_measure": True}) - spy = mocker.spy(qml, "defer_measurements") + spy = mocker.spy(qml.defer_measurements, "_transform") @qml.qnode(dev) def circuit(): @@ -1138,23 +1138,6 @@ def conditional_ry_qnode(x, y): assert np.allclose(r2[1], mv_res(first_par)) assert spy.call_count == 3 if dev.name == "defaut.qubit" else 1 - def test_drawing_has_deferred_measurements(self): - """Test that `qml.draw` with qnodes uses defer_measurements - to draw circuits with mid-circuit measurements.""" - dev = qml.device("default.qubit.legacy", wires=2) - - @qml.qnode(dev) - def circuit(x): - qml.RX(x, wires=0) - m = qml.measure(0) - qml.cond(m, qml.PauliX)(wires=1) - return qml.expval(qml.PauliZ(wires=1)) - - res = qml.draw(circuit)("x") - expected = "0: ──RX(x)─╭●─┤ \n1: ────────╰X─┤ " - - assert res == expected - @pytest.mark.parametrize("basis_state", [[1, 0], [0, 1]]) def test_sampling_with_mcm(self, basis_state, mocker): """Tests that a QNode with qml.sample and mid-circuit measurements @@ -1179,11 +1162,11 @@ def conditional_ry_qnode(x): qml.cond(m_0, qml.RY)(x, wires=1) return qml.sample(qml.PauliZ(1)) - spy = mocker.spy(qml, "defer_measurements") + spy = mocker.spy(qml.defer_measurements, "_transform") r1 = cry_qnode(first_par) r2 = conditional_ry_qnode(first_par) assert np.allclose(r1, r2) - spy.assert_called_once() + spy.assert_called() @pytest.mark.tf @pytest.mark.parametrize("interface", ["tf", "auto"]) diff --git a/tests/test_qubit_device.py b/tests/test_qubit_device.py index 611bae8e93f..f8a10f26ac4 100644 --- a/tests/test_qubit_device.py +++ b/tests/test_qubit_device.py @@ -14,6 +14,7 @@ """ Unit tests for the :mod:`pennylane` :class:`QubitDevice` class. """ +import copy from random import random import numpy as np @@ -1147,6 +1148,62 @@ def test_defines_correct_capabilities(self): assert capabilities == QubitDevice.capabilities() +class TestNativeMidCircuitMeasurements: + """Unit tests for mid-circuit measurements related functionality""" + + class MCMDevice(qml.devices.DefaultQubitLegacy): + def apply(self, *args, **kwargs): + for op in args[0]: + if isinstance(op, qml.measurements.MidMeasureMP): + kwargs["mid_measurements"][op] = 0 + + @classmethod + def capabilities(cls): + default_capabilities = copy.copy(qml.devices.DefaultQubitLegacy.capabilities()) + default_capabilities["supports_mid_measure"] = True + return default_capabilities + + def test_qnode_native_mcm(self, mocker): + """Tests that the legacy devices may support native MCM execution via the dynamic_one_shot transform.""" + + dev = self.MCMDevice(wires=1, shots=100) + dev.operations.add("MidMeasureMP") + spy = mocker.spy(qml.dynamic_one_shot, "_transform") + + @qml.qnode(dev, interface=None, diff_method=None) + def func(): + _ = qml.measure(0) + return qml.expval(op=qml.PauliZ(0)) + + res = func() + assert spy.call_count == 1 + assert isinstance(res, float) + + @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) + def test_postselect_mode_propagates_to_execute(self, monkeypatch, postselect_mode): + """Test that the specified postselect mode propagates to execution as expected.""" + dev = self.MCMDevice(wires=1, shots=100) + dev.operations.add("MidMeasureMP") + pm_propagated = False + + def new_apply(*args, **kwargs): # pylint: disable=unused-argument + nonlocal pm_propagated + pm_propagated = kwargs.get("postselect_mode", -1) == postselect_mode + + @qml.qnode(dev, postselect_mode=postselect_mode) + def func(): + _ = qml.measure(0, postselect=1) + return qml.expval(op=qml.PauliZ(0)) + + with monkeypatch.context() as m: + m.setattr(dev, "apply", new_apply) + with pytest.raises(Exception): + # Error expected as mocked apply method does not adhere to expected output. + func() + + assert pm_propagated is True + + class TestExecution: """Tests for the execute method""" diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index 4d4faea65db..e1c7353e024 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -105,6 +105,30 @@ def circ(): _ = circ() +@pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) +def test_postselect_mode(postselect_mode, mocker): + """Test that invalid shots are discarded if requested""" + shots = 100 + postselect_value = 1 + dev = qml.device("default.qubit", shots=shots) + spy = mocker.spy(qml.defer_measurements, "_transform") + + @qml.qnode(dev, postselect_mode=postselect_mode, mcm_method="deferred") + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=postselect_value) + return qml.sample(wires=[0]) + + res = f(np.pi / 4) + spy.assert_called_once() + + if postselect_mode == "hw-like": + assert len(res) < shots + else: + assert len(res) == shots + assert np.allclose(res, postselect_value) + + @pytest.mark.parametrize( "mp, err_msg", [ diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index ce9d67a429b..6c908840359 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -61,6 +61,55 @@ def _(): return qml.probs(wires=[0]) +@pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) +def test_postselect_mode(postselect_mode, mocker): + """Test that invalid shots are discarded if requested""" + shots = 100 + dev = qml.device("default.qubit", shots=shots) + spy = mocker.spy(qml, "dynamic_one_shot") + + @qml.qnode(dev, postselect_mode=postselect_mode) + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.sample(wires=[0, 1]) + + res = f(np.pi / 2) + spy.assert_called_once() + + if postselect_mode == "hw-like": + assert len(res) < shots + else: + assert len(res) == shots + assert np.all(res != np.iinfo(np.int32).min) + + +@pytest.mark.jax +@pytest.mark.parametrize("use_jit", [True, False]) +@pytest.mark.parametrize("diff_method", [None, "best"]) +def test_hw_like_with_jax(use_jit, diff_method): + """Test that invalid shots are replaced with INTEGER_MIN_VAL if + postselect_mode="hw-like" with JAX""" + import jax # pylint: disable=import-outside-toplevel + + shots = 10 + dev = qml.device("default.qubit", shots=shots, seed=jax.random.PRNGKey(123)) + + @qml.qnode(dev, postselect_mode="hw-like", diff_method=diff_method) + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.sample(wires=[0, 1]) + + if use_jit: + f = jax.jit(f) + + res = f(jax.numpy.array(np.pi / 2)) + + assert len(res) == shots + assert np.any(res == np.iinfo(np.int32).min) + + def test_unsupported_measurements(): """Test that using unsupported measurements raises an error.""" tape = qml.tape.QuantumScript([MidMeasureMP(0)], [qml.state()]) diff --git a/tests/transforms/test_tape_expand.py b/tests/transforms/test_tape_expand.py index 55f65a7cabc..715320b48e8 100644 --- a/tests/transforms/test_tape_expand.py +++ b/tests/transforms/test_tape_expand.py @@ -931,7 +931,6 @@ def circuit(): _ = circuit() decomp_ops = circuit.tape.operations - print(decomp_ops) assert len(decomp_ops) == 4 if shots is None else 5 assert decomp_ops[0].name == "RZ" @@ -940,5 +939,10 @@ def circuit(): assert decomp_ops[1].name == "RY" assert np.isclose(decomp_ops[1].parameters[0], np.pi / 2) - assert decomp_ops[2].name == "CNOT" - assert decomp_ops[3].name == "CNOT" + if shots: + assert decomp_ops[2].name == "MidMeasureMP" + assert decomp_ops[3].name == "CNOT" + assert decomp_ops[4].name == "MidMeasureMP" + else: + assert decomp_ops[2].name == "CNOT" + assert decomp_ops[3].name == "CNOT"