From 44018e9ae06ae11590288ffef2f8ec348560bb28 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 22 Apr 2024 16:33:55 -0400 Subject: [PATCH] Use SampleMPs in dynamic_one_shot (#5486) ### Before submitting Please complete the following checklist when submitting a PR: - [x] All new features must include a unit test. If you've fixed a bug or added code that should be tested, add a test to the test directory! - [x] All new functions and code must be clearly commented and documented. If you do make documentation changes, make sure that the docs build and render correctly by running `make docs`. - [x] Ensure that the test suite passes, by running `make test`. - [x] Add a new entry to the `doc/releases/changelog-dev.md` file, summarizing the change, and including a link back to the PR. - [x] The PennyLane source code conforms to [PEP8 standards](https://www.python.org/dev/peps/pep-0008/). We check all of our code against [Pylint](https://www.pylint.org/). To lint modified files, simply `pip install pylint`, and then run `pylint pennylane/path/to/file.py`. When all the above are checked, delete everything above the dashed line and fill in the pull request template. ------------------------------------------------------------------------------------------------------------ **Context:** The native MCM workflow breaks the device API where a sequence of MeasurementProcess objects is expected in the output. **Description of the Change:** Introduce SampleMPs in the auxiliary tape. Pass the mid_measurements dictionary around simulate and sampling to return the correct sample measurements. Modify the dynamic_one_shot transform post-processing function accordingly. **Benefits:** Conform to current API. Road to jax.jit support. **Possible Drawbacks:** Ad hoc post-processing required in measure_with_samples. **Related GitHub Issues:** [sc-60945] --------- Co-authored-by: Christina Lee --- doc/releases/changelog-dev.md | 21 ++- pennylane/_qubit_device.py | 25 +-- pennylane/devices/qubit/apply_operation.py | 21 +-- pennylane/devices/qubit/sampling.py | 30 +++- pennylane/devices/qubit/simulate.py | 37 ++-- pennylane/tape/tape.py | 33 +++- pennylane/transforms/dynamic_one_shot.py | 51 +++++- .../test_default_qubit_native_mcm.py | 64 ++++--- tests/devices/qubit/test_simulate.py | 17 +- tests/devices/test_default_qubit_legacy.py | 7 +- tests/math/test_functions.py | 36 ++++ tests/transforms/test_dynamic_one_shot.py | 168 ++++++++++++++++++ 12 files changed, 399 insertions(+), 111 deletions(-) create mode 100644 tests/transforms/test_dynamic_one_shot.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 72e7bcfef62..87ed03efc7d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -161,7 +161,7 @@ >>> circuit() tensor([1.+6.123234e-17j, 0.-6.123234e-17j], requires_grad=True) ``` - + * The `qml.AmplitudeAmplification` operator is introduced, which is a high-level interface for amplitude amplification and its variants. [(#5160)](https://github.com/PennyLaneAI/pennylane/pull/5160) @@ -185,7 +185,7 @@ return qml.probs(wires=range(3)) ``` - + ```pycon >>> print(np.round(circuit(), 3)) [0.013, 0.013, 0.91, 0.013, 0.013, 0.013, 0.013, 0.013] @@ -212,7 +212,7 @@ but for usage with new operator arithmetic. [(#5216)](https://github.com/PennyLaneAI/pennylane/pull/5216) -* The `qml.TrotterProduct` operator now supports error estimation functionality. +* The `qml.TrotterProduct` operator now supports error estimation functionality. [(#5384)](https://github.com/PennyLaneAI/pennylane/pull/5384) ```pycon @@ -245,18 +245,18 @@ * The `molecular_hamiltonian` function calls `PySCF` directly when `method='pyscf'` is selected. [(#5118)](https://github.com/PennyLaneAI/pennylane/pull/5118) -* The generators in the source code return operators consistent with the global setting for - `qml.operator.active_new_opmath()` wherever possible. `Sum`, `SProd` and `Prod` instances - will be returned even after disabling the new operator arithmetic in cases where they offer +* The generators in the source code return operators consistent with the global setting for + `qml.operator.active_new_opmath()` wherever possible. `Sum`, `SProd` and `Prod` instances + will be returned even after disabling the new operator arithmetic in cases where they offer additional functionality not available using legacy operators. [(#5253)](https://github.com/PennyLaneAI/pennylane/pull/5253) [(#5410)](https://github.com/PennyLaneAI/pennylane/pull/5410) - [(#5411)](https://github.com/PennyLaneAI/pennylane/pull/5411) + [(#5411)](https://github.com/PennyLaneAI/pennylane/pull/5411) [(#5421)](https://github.com/PennyLaneAI/pennylane/pull/5421) * Upgraded `null.qubit` to the new device API. Also, added support for all measurements and various modes of differentiation. [(#5211)](https://github.com/PennyLaneAI/pennylane/pull/5211) - + * `ApproxTimeEvolution` is now compatible with any operator that defines a `pauli_rep`. [(#5362)](https://github.com/PennyLaneAI/pennylane/pull/5362) @@ -338,6 +338,9 @@

Breaking changes 💔

+* Use `SampleMP`s in the `dynamic_one_shot` transform to get back the values of the mid-circuit measurements. + [(#5486)](https://github.com/PennyLaneAI/pennylane/pull/5486) + * Operator dunder methods now combine like-operator arithmetic classes via `lazy=False`. This reduces the chance of `RecursionError` and makes nested operators easier to work with. [(#5478)](https://github.com/PennyLaneAI/pennylane/pull/5478) @@ -359,7 +362,7 @@ * `qml.pauli.pauli_mult` and `qml.pauli.pauli_mult_with_phase` are now removed. Instead, you should use `qml.simplify(qml.prod(pauli_1, pauli_2))` to get the reduced operator. [(#5324)](https://github.com/PennyLaneAI/pennylane/pull/5324) - + ```pycon >>> op = qml.simplify(qml.prod(qml.PauliX(0), qml.PauliZ(0))) >>> op diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py index a922769351f..ba8b031219a 100644 --- a/pennylane/_qubit_device.py +++ b/pennylane/_qubit_device.py @@ -285,12 +285,6 @@ def execute(self, circuit, **kwargs): ) if has_mcm: mid_measurements = kwargs["mid_measurements"] - mid_values = np.array(tuple(mid_measurements.values())) - if np.any(mid_values == -1): - for k, v in tuple(mid_measurements.items()): - if v == -1: - mid_measurements.pop(k) - return None, mid_measurements # generate computational basis samples sample_type = (SampleMP, CountsMP, ClassicalShadowMP, ShadowExpvalMP) @@ -308,13 +302,24 @@ def execute(self, circuit, **kwargs): self.apply([qml.adjoint(g, lazy=False) for g in reversed(diagonalizing_gates)]) # compute the required statistics + if has_mcm: + n_mcms = len(mid_measurements) + stat_circuit = qml.tape.QuantumScript( + circuit.operations, + circuit.measurements[0:-n_mcms], + shots=1, + trainable_params=circuit.trainable_params, + ) + else: + stat_circuit = circuit if self._shot_vector is not None: - results = self.shot_vec_statistics(circuit) + results = self.shot_vec_statistics(stat_circuit) else: - results = self.statistics(circuit) + results = self.statistics(stat_circuit) + if has_mcm: + results.extend(list(mid_measurements.values())) single_measurement = len(circuit.measurements) == 1 - results = results[0] if single_measurement else tuple(results) # increment counter for number of executions of qubit device self._num_executions += 1 @@ -336,7 +341,7 @@ def execute(self, circuit, **kwargs): ) self.tracker.record() - return (results, mid_measurements) if has_mcm else results + return results def shot_vec_statistics(self, circuit: QuantumTape): """Process measurement results from circuit execution using a device diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index 61ea0956a47..f1597664509 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -16,12 +16,12 @@ from functools import singledispatch from string import ascii_letters as alphabet + import numpy as np import pennylane as qml - from pennylane import math -from pennylane.measurements import MidMeasureMP +from pennylane.measurements import MidMeasureMP, Shots from pennylane.ops import Conditional SQRT2INV = 1 / math.sqrt(2) @@ -261,21 +261,16 @@ def apply_mid_measure( if is_state_batched: raise ValueError("MidMeasureMP cannot be applied to batched states.") if not np.allclose(np.linalg.norm(state), 1.0): - mid_measurements[op] = 0 + mid_measurements[op] = -1 return np.zeros_like(state) wire = op.wires - probs = qml.devices.qubit.measure(qml.probs(wire), state) - - try: # pragma: no cover - sample = np.random.binomial(1, probs[1]) - except ValueError as e: # pragma: no cover - if probs[1] > 1: # MachEps error, safe to catch - sample = np.random.binomial(1, np.round(probs[1], 15)) - else: # Other general error, continue to fail - raise e - + sample = qml.devices.qubit.sampling.measure_with_samples( + [qml.sample(wires=wire)], state, Shots(1) + ) + sample = int(sample[0]) mid_measurements[op] = sample if op.postselect is not None and sample != op.postselect: + mid_measurements[op] = -1 return np.zeros_like(state) axis = wire.toarray()[0] slices = [slice(None)] * qml.math.ndim(state) diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py index 4264b0fa977..e18b5dcec35 100644 --- a/pennylane/devices/qubit/sampling.py +++ b/pennylane/devices/qubit/sampling.py @@ -12,20 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. """Functions to sample a state.""" -from typing import List, Union, Tuple +from typing import List, Tuple, Union import numpy as np + import pennylane as qml -from pennylane.ops import Sum, Hamiltonian, LinearCombination from pennylane.measurements import ( - SampleMeasurement, - Shots, - ExpectationMP, ClassicalShadowMP, - ShadowExpvalMP, CountsMP, + ExpectationMP, + SampleMeasurement, + ShadowExpvalMP, + Shots, ) +from pennylane.ops import Hamiltonian, LinearCombination, Sum from pennylane.typing import TensorLike + from .apply_operation import apply_operation from .measure import flatten_state @@ -165,12 +167,13 @@ def _apply_diagonalizing_gates( # pylint:disable = too-many-arguments def measure_with_samples( - mps: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]], + measurements: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]], state: np.ndarray, shots: Shots, is_state_batched: bool = False, rng=None, prng_key=None, + mid_measurements: dict = None, ) -> List[TensorLike]: """ Returns the samples of the measurement process performed on the given state. @@ -178,7 +181,7 @@ def measure_with_samples( have already been mapped to integer wires used in the device. Args: - mp (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]): + measurements (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]): The sample measurements to perform state (np.ndarray[complex]): The state vector to sample from shots (Shots): The number of samples to take @@ -188,15 +191,22 @@ def measure_with_samples( If no value is provided, a default RNG will be used. prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. + mid_measurements (None, dict): Dictionary of mid-circuit measurements Returns: List[TensorLike[Any]]: Sample measurement results """ + # last N measurements are sampling MCMs in ``dynamic_one_shot`` execution mode + mps = measurements[0 : -len(mid_measurements)] if mid_measurements else measurements + skip_measure = any(v == -1 for v in mid_measurements.values()) if mid_measurements else False groups, indices = _group_measurements(mps) all_res = [] for group in groups: + if skip_measure: + all_res.extend([None] * len(group)) + continue if isinstance(group[0], ExpectationMP) and isinstance( group[0].obs, (Hamiltonian, LinearCombination) ): @@ -223,6 +233,10 @@ def measure_with_samples( res for _, res in sorted(list(enumerate(all_res)), key=lambda r: flat_indices[r[0]]) ) + # append MCM samples + if mid_measurements: + sorted_res += tuple(mid_measurements.values()) + # put the shot vector axis before the measurement axis if shots.has_partitioned_shots: sorted_res = tuple(zip(*sorted_res)) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 0605d267d67..c31043e6bc1 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -15,21 +15,18 @@ # pylint: disable=protected-access from typing import Optional -from numpy.random import default_rng import numpy as np +from numpy.random import default_rng import pennylane as qml -from pennylane.measurements import ( - MidMeasureMP, -) +from pennylane.measurements import MidMeasureMP from pennylane.typing import Result -from .initialize_state import create_initial_state from .apply_operation import apply_operation +from .initialize_state import create_initial_state from .measure import measure from .sampling import jax_random_split, measure_with_samples - INTERFACE_TO_LIKE = { # map interfaces known by autoray to themselves None: None, @@ -153,7 +150,10 @@ def get_final_state(circuit, debugger=None, interface=None, mid_measurements=Non return state, is_state_batched -def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=None) -> Result: +# pylint: disable=too-many-arguments +def measure_final_state( + circuit, state, is_state_batched, rng=None, prng_key=None, mid_measurements: dict = None +) -> Result: """ Perform the measurements required by the circuit on the provided state. @@ -170,6 +170,7 @@ def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=Non the key to the JAX pseudo random number generator. Only for simulation using JAX. If None, the default ``sample_state`` function and a ``numpy.random.default_rng`` will be for sampling. + mid_measurements (None, dict): Dictionary of mid-circuit measurements Returns: Tuple[TensorLike]: The measurement results @@ -177,8 +178,11 @@ def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=Non circuit = circuit.map_to_standard_wires() + # analytic case + if not circuit.shots: - # analytic case + if mid_measurements is not None: + raise TypeError("Native mid-circuit measurements are only supported with finite shots.") if len(circuit.measurements) == 1: return measure(circuit.measurements[0], state, is_state_batched=is_state_batched) @@ -197,6 +201,7 @@ def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=Non is_state_batched=is_state_batched, rng=rng, prng_key=prng_key, + mid_measurements=mid_measurements, ) if len(circuit.measurements) == 1: @@ -283,13 +288,15 @@ def simulate_one_shot_native_mcm( dict: The mid-circuit measurement results of the simulation """ _, key = jax_random_split(prng_key) - mcm_dict = {} + mid_measurements = {} state, is_state_batched = get_final_state( - circuit, debugger=debugger, interface=interface, mid_measurements=mcm_dict + circuit, debugger=debugger, interface=interface, mid_measurements=mid_measurements ) - if not np.allclose(np.linalg.norm(state), 1.0): - return None, mcm_dict - return ( - measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=key), - mcm_dict, + return measure_final_state( + circuit, + state, + is_state_batched, + rng=rng, + prng_key=key, + mid_measurements=mid_measurements, ) diff --git a/pennylane/tape/tape.py b/pennylane/tape/tape.py index 56520e2b45b..39ae262f430 100644 --- a/pennylane/tape/tape.py +++ b/pennylane/tape/tape.py @@ -19,10 +19,16 @@ from threading import RLock import pennylane as qml -from pennylane.measurements import CountsMP, ProbabilityMP, SampleMP, MeasurementProcess +from pennylane.measurements import ( + CountsMP, + MeasurementProcess, + MidMeasureMP, + ProbabilityMP, + SampleMP, +) from pennylane.operation import DecompositionUndefinedError, Operator, StatePrepBase -from pennylane.queuing import AnnotatedQueue, QueuingManager, process_queue from pennylane.pytrees import register_pytree +from pennylane.queuing import AnnotatedQueue, QueuingManager, process_queue from .qscript import QuantumScript @@ -40,29 +46,40 @@ def _err_msg_for_some_meas_not_qwc(measurements): ) -def _validate_computational_basis_sampling(measurements): +def _validate_computational_basis_sampling(tape): """Auxiliary function for validating computational basis state sampling with other measurements considering the qubit-wise commutativity relation.""" + measurements = tape.measurements + n_meas = len(measurements) + n_mcms = sum(isinstance(op, MidMeasureMP) for op in tape.operations) non_comp_basis_sampling_obs = [] comp_basis_sampling_obs = [] - for o in measurements: + comp_basis_indices = [] + for i, o in enumerate(measurements): if o.samples_computational_basis: comp_basis_sampling_obs.append(o) + comp_basis_indices.append(i) else: non_comp_basis_sampling_obs.append(o) if non_comp_basis_sampling_obs: all_wires = [] empty_wires = qml.wires.Wires([]) - for idx, cb_obs in enumerate(comp_basis_sampling_obs): + for idx, (cb_obs, global_idx) in enumerate( + zip(comp_basis_sampling_obs, comp_basis_indices) + ): if cb_obs.wires == empty_wires: all_wires = qml.wires.Wires.all_wires([m.wires for m in measurements]) break - - all_wires.append(cb_obs.wires) + if global_idx < n_meas - n_mcms: + all_wires.append(cb_obs.wires) if idx == len(comp_basis_sampling_obs) - 1: all_wires = qml.wires.Wires.all_wires(all_wires) + # This happens when a MeasurementRegisterMP is the only computational basis state measurement + if all_wires == empty_wires: + return + with QueuingManager.stop_recording(): # stop recording operations - the constructed operator is just aux pauliz_for_cb_obs = ( qml.Z(all_wires) @@ -176,7 +193,7 @@ def stop_at(obj): # pylint: disable=unused-argument # rotations and the observables updated to the computational basis. Note that this # expansion acts on the original tape in place. if tape.samples_computational_basis and len(tape.measurements) > 1: - _validate_computational_basis_sampling(tape.measurements) + _validate_computational_basis_sampling(tape) diagonalizing_gates, diagonal_measurements = rotations_and_diagonal_measurements(tape) for queue, new_queue in [ diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index bac1b05391f..98625aa6042 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -14,10 +14,12 @@ """ Contains the batch dimension transform. """ +import warnings + # pylint: disable=import-outside-toplevel from collections import Counter +from itertools import compress from typing import Callable, Sequence -import warnings import numpy as np @@ -25,6 +27,7 @@ from pennylane.measurements import ( CountsMP, ExpectationMP, + MeasurementValue, MidMeasureMP, ProbabilityMP, SampleMP, @@ -91,6 +94,9 @@ def func(x, y): "measurements." ) + if not tape.shots: + raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.") + samples_present = any(isinstance(mp, SampleMP) for mp in tape.measurements) postselect_present = any( op.postselect is not None for op in tape.operations if isinstance(op, MidMeasureMP) @@ -137,16 +143,38 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None): ) del results[0:s] return tuple(final_results) + all_mcms = [op for op in aux_tapes[0].operations if isinstance(op, MidMeasureMP)] + n_mcms = len(all_mcms) + post_process_tape = qml.tape.QuantumScript( + aux_tapes[0].operations, + aux_tapes[0].measurements[0:-n_mcms], + shots=aux_tapes[0].shots, + trainable_params=aux_tapes[0].trainable_params, + ) + single_measurement = ( + len(post_process_tape.measurements) == 0 and len(aux_tapes[0].measurements) == 1 + ) + mcm_samples = np.zeros((len(results), n_mcms), dtype=np.int64) + for i, res in enumerate(results): + mcm_samples[i, :] = [res] if single_measurement else res[-n_mcms::] + mcm_mask = qml.math.all(mcm_samples != -1, axis=1) + mcm_samples = mcm_samples[mcm_mask, :] + results = list(compress(results, mcm_mask)) # The following code assumes no broadcasting and no shot vectors. The above code should # handle those cases all_shot_meas, list_mcm_values_dict, valid_shots = None, [], 0 - for res in results: - one_shot_meas, mcm_values_dict = res - if one_shot_meas is None: - continue + for i, res in enumerate(results): + samples = [res] if single_measurement else res[-n_mcms::] valid_shots += 1 - all_shot_meas = accumulate_native_mcm(aux_tapes[0], all_shot_meas, one_shot_meas) + mcm_values_dict = dict((k, v) for k, v in zip(all_mcms, samples)) + if len(post_process_tape.measurements) == 0: + one_shot_meas = [] + elif len(post_process_tape.measurements) == 1: + one_shot_meas = res[0] + else: + one_shot_meas = res[0:-n_mcms] + all_shot_meas = accumulate_native_mcm(post_process_tape, all_shot_meas, one_shot_meas) list_mcm_values_dict.append(mcm_values_dict) if not valid_shots: warnings.warn( @@ -203,6 +231,10 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): new_measurements.append(SampleMP(obs=m.obs)) else: new_measurements.append(m) + for op in circuit: + if isinstance(op, MidMeasureMP): + new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) + return qml.tape.QuantumScript( circuit.operations, new_measurements, shots=1, trainable_params=circuit.trainable_params ) @@ -327,11 +359,14 @@ def gather_mcm(measurement, samples): mcm_samples = np.concatenate(mcm_samples, axis=1) meas_tmp = measurement.__class__(wires=wires) return meas_tmp.process_samples(mcm_samples, wire_order=wires) + mcm_samples = np.zeros((len(samples), 1), dtype=np.int64) if isinstance(measurement, ProbabilityMP): - mcm_samples = np.array([dct[mv.measurements[0]] for dct in samples]).reshape((-1, 1)) + for i, dct in enumerate(samples): + mcm_samples[i, 0] = dct[mv.measurements[0]] use_as_is = True else: - mcm_samples = np.array([mv.concretize(dct) for dct in samples]).reshape((-1, 1)) + for i, dct in enumerate(samples): + mcm_samples[i, 0] = mv.concretize(dct) use_as_is = mv.branches == {(0,): 0, (1,): 1} if use_as_is: wires, meas_tmp = mv.wires, measurement diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index b2589096d78..82717247eba 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -13,18 +13,14 @@ # limitations under the License. """Tests for default qubit preprocessing.""" from functools import reduce -from typing import Sequence, Iterable +from typing import Iterable, Sequence -from flaky import flaky import numpy as np import pytest +from flaky import flaky import pennylane as qml -from pennylane.devices.qubit.apply_operation import apply_mid_measure, MidMeasureMP -from pennylane.transforms.dynamic_one_shot import ( - accumulate_native_mcm, - parse_native_mid_circuit_measurements, -) +from pennylane.devices.qubit.apply_operation import MidMeasureMP, apply_mid_measure pytestmark = pytest.mark.slow @@ -143,21 +139,13 @@ def test_apply_mid_measure(): m0 = MidMeasureMP(0, postselect=1) mid_measurements = {} state = apply_mid_measure(m0, np.zeros(2), mid_measurements=mid_measurements) - assert mid_measurements[m0] == 0 + assert mid_measurements[m0] == -1 assert np.allclose(state, 0.0) state = apply_mid_measure(m0, np.array([1, 0]), mid_measurements=mid_measurements) - assert mid_measurements[m0] == 0 + assert mid_measurements[m0] == -1 assert np.allclose(state, 0.0) -def test_accumulate_native_mcm_unsupported_error(): - with pytest.raises( - TypeError, - match=f"Native mid-circuit measurement mode does not support {type(qml.var(qml.PauliZ(0))).__name__}", - ): - accumulate_native_mcm(qml.tape.QuantumScript([], [qml.var(qml.PauliZ(0))]), [None], [None]) - - def test_all_invalid_shots_circuit(): dev = qml.device("default.qubit") @@ -195,23 +183,6 @@ def circuit_mcm(): assert np.all(np.isnan(r2)) -@pytest.mark.parametrize( - "measurement", - [ - qml.state(), - qml.density_matrix(0), - qml.vn_entropy(0), - qml.mutual_info(0, 1), - qml.purity(0), - qml.classical_shadow(0), - ], -) -def test_parse_native_mid_circuit_measurements_unsupported_meas(measurement): - circuit = qml.tape.QuantumScript([qml.RX(1, 0)], [measurement]) - with pytest.raises(TypeError, match="Native mid-circuit measurement mode does not support"): - parse_native_mid_circuit_measurements(circuit, None, None) - - def test_unsupported_measurement(): dev = qml.device("default.qubit", shots=1000) params = np.pi / 4 * np.ones(2) @@ -311,6 +282,31 @@ def func(x, y, z): validate_measurements(measure_f, shots, results1, results2) +@flaky(max_runs=5) +@pytest.mark.parametrize("postselect", [None, 0, 1]) +@pytest.mark.parametrize("reset", [False, True]) +def test_single_mcm_multiple_measure_obs(postselect, reset): + """Tests that DefaultQubit handles a circuit with a single mid-circuit measurement and a + conditional gate. Multiple measurements of common observables are performed at the end.""" + + dev = qml.device("default.qubit", shots=5000) + params = [np.pi / 7, np.pi / 6, -np.pi / 5] + + @qml.qnode(dev) + def func(x, y, z): + obs_tape(x, y, z, reset=reset, postselect=postselect) + return qml.counts(qml.PauliZ(0)), qml.expval(qml.PauliY(1)) + + func1 = func + func2 = qml.defer_measurements(func) + + results1 = func1(*params) + results2 = func2(*params) + + for measure_f, res1, res2 in zip([qml.counts, qml.expval], results1, results2): + validate_measurements(measure_f, 5000, res1, res2) + + @flaky(max_runs=5) @pytest.mark.parametrize("shots", [None, 3000, [3000, 3001]]) @pytest.mark.parametrize("postselect", [None, 0, 1]) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 3f226f02e6d..5e3354ff7b4 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -13,12 +13,11 @@ # limitations under the License. """Unit tests for simulate in devices/qubit.""" -import pytest - import numpy as np +import pytest import pennylane as qml -from pennylane.devices.qubit import simulate, get_final_state, measure_final_state +from pennylane.devices.qubit import get_final_state, measure_final_state, simulate class TestCurrentlyUnsupportedCases: @@ -66,6 +65,18 @@ def test_basis_state(self): class TestBasicCircuit: """Tests a basic circuit with one rx gate and two simple expectation values.""" + def test_analytic_mid_meas_raise(self): + """Test measure_final_state raises an error when getting a mid-measurement dictionary.""" + phi = np.array(0.397) + qs = qml.tape.QuantumScript( + [qml.RX(phi, wires=0)], [qml.expval(qml.PauliY(0)), qml.expval(qml.PauliZ(0))] + ) + state, is_state_batched = get_final_state(qs) + with pytest.raises( + TypeError, match="Native mid-circuit measurements are only supported with finite shots." + ): + _ = measure_final_state(qs, state, is_state_batched, mid_measurements={}) + def test_basic_circuit_numpy(self): """Test execution with a basic circuit.""" phi = np.array(0.397) diff --git a/tests/devices/test_default_qubit_legacy.py b/tests/devices/test_default_qubit_legacy.py index 1450edc74cd..a5b4df81cec 100644 --- a/tests/devices/test_default_qubit_legacy.py +++ b/tests/devices/test_default_qubit_legacy.py @@ -18,10 +18,9 @@ # pylint: disable=protected-access,cell-var-from-loop import cmath import copy - import math - from functools import partial + import pytest import pennylane as qml @@ -95,7 +94,9 @@ def test_qnode_native_mcm(mocker): class MCMDevice(DefaultQubitLegacy): def apply(self, *args, **kwargs): - pass + for op in args[0]: + if isinstance(op, qml.measurements.MidMeasureMP): + kwargs["mid_measurements"][op] = 0 @classmethod def capabilities(cls): diff --git a/tests/math/test_functions.py b/tests/math/test_functions.py index 81ae38f2fa9..043cd595906 100644 --- a/tests/math/test_functions.py +++ b/tests/math/test_functions.py @@ -186,6 +186,42 @@ def test_allequal(t1, t2): assert res == expected +test_all_vectors = [ + ((False, False, False), False), + ((True, True, False), False), + ((True, True, True), True), +] + + +@pytest.mark.parametrize( + "array_fn", [tuple, list, onp.array, np.array, torch.tensor, tf.Variable, tf.constant] +) +@pytest.mark.parametrize("t1, expected", test_all_vectors) +def test_all(array_fn, t1, expected): + """Test that the all function works for a variety of inputs.""" + res = fn.all(array_fn(t1)) + + assert res == expected + + +test_any_vectors = [ + ((False, False, False), False), + ((True, True, False), True), + ((True, True, True), True), +] + + +@pytest.mark.parametrize( + "array_fn", [tuple, list, onp.array, np.array, torch.tensor, tf.Variable, tf.constant] +) +@pytest.mark.parametrize("t1, expected", test_any_vectors) +def test_any(array_fn, t1, expected): + """Test that the any function works for a variety of inputs.""" + res = fn.any(array_fn(t1)) + + assert res == expected + + @pytest.mark.parametrize( "t1,t2", list( diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py new file mode 100644 index 00000000000..282d6fe0523 --- /dev/null +++ b/tests/transforms/test_dynamic_one_shot.py @@ -0,0 +1,168 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the transform implementing the deferred measurement principle. +""" +# pylint: disable=too-few-public-methods, too-many-arguments + +import numpy as np +import pytest + +import pennylane as qml +from pennylane.measurements import ( + CountsMP, + ExpectationMP, + MeasurementValue, + MidMeasureMP, + ProbabilityMP, + SampleMP, +) +from pennylane.transforms.dynamic_one_shot import ( + accumulate_native_mcm, + parse_native_mid_circuit_measurements, +) + + +def test_accumulate_native_mcm_unsupported_error(): + with pytest.raises( + TypeError, + match=f"Native mid-circuit measurement mode does not support {type(qml.var(qml.PauliZ(0))).__name__}", + ): + accumulate_native_mcm(qml.tape.QuantumScript([], [qml.var(qml.PauliZ(0))]), [None], [None]) + + +@pytest.mark.parametrize( + "measurement", + [ + qml.state(), + qml.density_matrix(0), + qml.vn_entropy(0), + qml.mutual_info(0, 1), + qml.purity(0), + qml.classical_shadow(0), + ], +) +def test_parse_native_mid_circuit_measurements_unsupported_meas(measurement): + circuit = qml.tape.QuantumScript([qml.RX(1, 0)], [measurement]) + with pytest.raises(TypeError, match="Native mid-circuit measurement mode does not support"): + parse_native_mid_circuit_measurements(circuit, None, None) + + +def test_postselection_error_with_wrong_device(): + """Test that an error is raised when a device does not support native execution.""" + dev = qml.device("default.mixed", wires=2) + + with pytest.raises(TypeError, match="does not support mid-circuit measurements natively"): + + @qml.dynamic_one_shot + @qml.qnode(dev) + def _(): + qml.measure(0, postselect=1) + return qml.probs(wires=[0]) + + +def test_unsupported_measurements(): + """Test that using unsupported measurements raises an error.""" + tape = qml.tape.QuantumScript([MidMeasureMP(0)], [qml.state()]) + + with pytest.raises( + TypeError, + match="Native mid-circuit measurement mode does not support StateMP measurements.", + ): + _, _ = qml.dynamic_one_shot(tape) + + +def test_unsupported_shots(): + """Test that using shots=None raises an error.""" + tape = qml.tape.QuantumScript([MidMeasureMP(0)], [qml.probs(wires=0)], shots=None) + + with pytest.raises( + qml.QuantumFunctionError, + match="dynamic_one_shot is only supported with finite shots.", + ): + _, _ = qml.dynamic_one_shot(tape) + + +@pytest.mark.parametrize("n_shots", range(1, 10)) +def test_len_tapes(n_shots): + """Test that the transform produces the correct number of tapes.""" + tape = qml.tape.QuantumScript([MidMeasureMP(0)], [qml.expval(qml.PauliZ(0))], shots=n_shots) + tapes, _ = qml.dynamic_one_shot(tape) + assert len(tapes) == n_shots + + +@pytest.mark.parametrize("n_batch", range(1, 4)) +@pytest.mark.parametrize("n_shots", range(1, 4)) +def test_len_tape_batched(n_batch, n_shots): + """Test that the transform produces the correct number of tapes with batches.""" + params = np.random.rand(n_batch) + tape = qml.tape.QuantumScript( + [qml.RX(params, 0), MidMeasureMP(0, postselect=1), qml.CNOT([0, 1])], + [qml.expval(qml.PauliZ(0))], + shots=n_shots, + ) + tapes, _ = qml.dynamic_one_shot(tape) + assert len(tapes) == n_shots * n_batch + + +@pytest.mark.parametrize( + "measure, aux_measure, n_meas", + ( + (qml.counts, CountsMP, 1), + (qml.expval, ExpectationMP, 1), + (qml.probs, ProbabilityMP, 1), + (qml.sample, SampleMP, 1), + (qml.var, SampleMP, 1), + ), +) +def test_len_measurements_obs(measure, aux_measure, n_meas): + """Test that the transform produces the correct number of measurements in tapes measuring observables.""" + n_shots = 10 + n_mcms = 1 + tape = qml.tape.QuantumScript( + [qml.Hadamard(0)] + [MidMeasureMP(0)] * n_mcms, [measure(op=qml.PauliZ(0))], shots=n_shots + ) + tapes, _ = qml.dynamic_one_shot(tape) + assert len(tapes) == n_shots + aux_tape = tapes[0] + assert len(aux_tape.measurements) == n_meas + n_mcms + assert isinstance(aux_tape.measurements[0], aux_measure) + assert all(isinstance(m, SampleMP) for m in aux_tape.measurements[1:]) + + +@pytest.mark.parametrize( + "measure, aux_measure, n_meas", + ( + (qml.counts, SampleMP, 0), + (qml.expval, SampleMP, 0), + (qml.probs, SampleMP, 0), + (qml.sample, SampleMP, 0), + (qml.var, SampleMP, 0), + ), +) +def test_len_measurements_mcms(measure, aux_measure, n_meas): + """Test that the transform produces the correct number of measurements in tapes measuring MCMs.""" + n_shots = 10 + n_mcms = 1 + tape = qml.tape.QuantumScript( + [qml.Hadamard(0)] + [MidMeasureMP(0)] * n_mcms, + [measure(op=MeasurementValue([MidMeasureMP(0)], lambda x: x))], + shots=n_shots, + ) + tapes, _ = qml.dynamic_one_shot(tape) + assert len(tapes) == n_shots + aux_tape = tapes[0] + assert len(aux_tape.measurements) == n_meas + n_mcms + assert isinstance(aux_tape.measurements[0], aux_measure) + assert all(isinstance(m, SampleMP) for m in aux_tape.measurements[1:])