diff --git a/Makefile b/Makefile index 0a7573ec0b1..5f7186e4549 100644 --- a/Makefile +++ b/Makefile @@ -70,10 +70,10 @@ coverage: .PHONY:format format: ifdef check - isort --py 311 --profile black -l 100 -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check + isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests --check else - isort --py 311 --profile black -l 100 -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests + isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests endif diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 65578cf7017..abef2086ae7 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -79,6 +79,9 @@

Mid-circuit measurements and dynamic circuits

+* The `dynamic_one_shot` transform is made compatible with the Catalyst compiler. + [(#5766)](https://github.com/PennyLaneAI/pennylane/pull/5766) + * Rationalize MCM tests, removing most end-to-end tests from the native MCM test file, but keeping one that validates multiple mid-circuit measurements with any allowed return and interface end-to-end tests. diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py index 807be11d2e8..6552504cf03 100644 --- a/pennylane/measurements/counts.py +++ b/pennylane/measurements/counts.py @@ -317,7 +317,7 @@ def circuit(x): # remove nans mask = qml.math.isnan(samples) num_wires = shape[-1] - if np.any(mask): + if qml.math.any(mask): mask = np.logical_not(np.any(mask, axis=tuple(range(1, samples.ndim)))) samples = samples[mask, ...] diff --git a/pennylane/tape/tape.py b/pennylane/tape/tape.py index 9e7f3ff9cb5..59f21dae5dd 100644 --- a/pennylane/tape/tape.py +++ b/pennylane/tape/tape.py @@ -19,13 +19,7 @@ from threading import RLock import pennylane as qml -from pennylane.measurements import ( - CountsMP, - MeasurementProcess, - MidMeasureMP, - ProbabilityMP, - SampleMP, -) +from pennylane.measurements import CountsMP, MeasurementProcess, ProbabilityMP, SampleMP from pennylane.operation import DecompositionUndefinedError, Operator, StatePrepBase from pennylane.pytrees import register_pytree from pennylane.queuing import AnnotatedQueue, QueuingManager, process_queue @@ -51,7 +45,7 @@ def _validate_computational_basis_sampling(tape): qubit-wise commutativity relation.""" measurements = tape.measurements n_meas = len(measurements) - n_mcms = sum(isinstance(op, MidMeasureMP) for op in tape.operations) + n_mcms = sum(qml.transforms.is_mcm(op) for op in tape.operations) non_comp_basis_sampling_obs = [] comp_basis_sampling_obs = [] comp_basis_indices = [] @@ -68,10 +62,10 @@ def _validate_computational_basis_sampling(tape): for idx, (cb_obs, global_idx) in enumerate( zip(comp_basis_sampling_obs, comp_basis_indices) ): - if cb_obs.wires == empty_wires: - all_wires = qml.wires.Wires.all_wires([m.wires for m in measurements]) - break if global_idx < n_meas - n_mcms: + if cb_obs.wires == empty_wires: + all_wires = qml.wires.Wires.all_wires([m.wires for m in measurements]) + break all_wires.append(cb_obs.wires) if idx == len(comp_basis_sampling_obs) - 1: all_wires = qml.wires.Wires.all_wires(all_wires) diff --git a/pennylane/transforms/__init__.py b/pennylane/transforms/__init__.py index 7df898f2ee5..f69d978040b 100644 --- a/pennylane/transforms/__init__.py +++ b/pennylane/transforms/__init__.py @@ -288,7 +288,7 @@ def circuit(x, y): from .decompositions import clifford_t_decomposition from .defer_measurements import defer_measurements -from .dynamic_one_shot import dynamic_one_shot +from .dynamic_one_shot import dynamic_one_shot, is_mcm from .sign_expand import sign_expand from .hamiltonian_expand import hamiltonian_expand, sum_expand from .split_non_commuting import split_non_commuting diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 831e30031af..b05c1d583b0 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -40,6 +40,12 @@ fill_in_value = np.iinfo(np.int32).min +def is_mcm(operation): + """Returns True if the operation is a mid-circuit measurement and False otherwise.""" + mcm = isinstance(operation, MidMeasureMP) + return mcm or "MidCircuitMeasure" in str(type(operation)) + + def null_postprocessing(results): """A postprocessing function returned by a transform that only converts the batch of results into a result for a single ``QuantumTape``. @@ -118,6 +124,9 @@ def func(x, y): aux_tapes = [init_auxiliary_tape(t) for t in tapes] + def reshape_data(array): + return qml.math.squeeze(qml.math.vstack(array)) + def processing_fn(results, has_partitioned_shots=None, batched_results=None): if batched_results is None and batch_size is not None: # If broadcasting, recursively process the results for each batch. For each batch @@ -141,6 +150,14 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None): return tuple(final_results) if not tape.shots.has_partitioned_shots: results = results[0] + + is_scalar = not isinstance(results[0], Sequence) + if is_scalar: + results = [reshape_data(tuple(results))] + else: + results = [ + reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0]) + ] return parse_native_mid_circuit_measurements(tape, aux_tapes, results) return aux_tapes, processing_fn @@ -170,12 +187,6 @@ def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs): return self.default_qnode_transform(qnode, targs, tkwargs) -def is_mcm(operation): - """Returns True if the operation is a mid-circuit measurement and False otherwise.""" - mcm = isinstance(operation, MidMeasureMP) - return mcm or "MidCircuitMeasure" in str(type(operation)) - - def init_auxiliary_tape(circuit: qml.tape.QuantumScript): """Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations. @@ -195,10 +206,11 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): new_measurements.append(SampleMP(obs=m.obs)) else: new_measurements.append(m) - for op in circuit: - if is_mcm(op): + for op in circuit.operations: + if "MidCircuitMeasure" in str(type(op)): # pragma: no cover + new_measurements.append(qml.sample(op.out_classical_tracers[0])) + elif isinstance(op, MidMeasureMP): new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) - return qml.tape.QuantumScript( circuit.operations, new_measurements, @@ -207,14 +219,15 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): ) +# pylint: disable=too-many-branches,too-many-statements def parse_native_mid_circuit_measurements( circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results: TensorLike ): """Combines, gathers and normalizes the results of native mid-circuit measurement runs. Args: - circuit (QuantumTape): Initial ``QuantumScript`` - aux_tapes (List[QuantumTape]): List of auxilary ``QuantumScript`` objects + circuit (QuantumTape): The original ``QuantumScript`` + aux_tapes (List[QuantumTape]): List of auxiliary ``QuantumScript`` objects results (TensorLike): Array of measurement results Returns: @@ -230,21 +243,12 @@ def measurement_with_no_shots(measurement): interface = qml.math.get_deep_interface(circuit.data) interface = "numpy" if interface == "builtins" else interface + active_qjit = qml.compiler.active() all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)] n_mcms = len(all_mcms) - post_process_tape = qml.tape.QuantumScript( - aux_tapes[0].operations, - aux_tapes[0].measurements[0:-n_mcms], - shots=aux_tapes[0].shots, - trainable_params=aux_tapes[0].trainable_params, - ) - single_measurement = ( - len(post_process_tape.measurements) == 0 and len(aux_tapes[0].measurements) == 1 - ) - mcm_samples = qml.math.array( - [[res] if single_measurement else res[-n_mcms::] for res in results], like=interface - ) + mcm_samples = qml.math.hstack(tuple(res.reshape((-1, 1)) for res in results[-n_mcms:])) + mcm_samples = qml.math.array(mcm_samples, like=interface) # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1 has_postselect = qml.math.array( [[int(op.postselect is not None) for op in all_mcms]], like=interface @@ -257,7 +261,6 @@ def measurement_with_no_shots(measurement): mid_meas = [op for op in circuit.operations if is_mcm(op)] mcm_samples = [mcm_samples[:, i : i + 1] for i in range(n_mcms)] mcm_samples = dict((k, v) for k, v in zip(mid_meas, mcm_samples)) - normalized_meas = [] m_count = 0 for m in circuit.measurements: @@ -267,18 +270,31 @@ def measurement_with_no_shots(measurement): ) if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) + elif m.mv and active_qjit: + meas = gather_mcm_qjit(m, mcm_samples, is_valid) # pragma: no cover elif m.mv: meas = gather_mcm(m, mcm_samples, is_valid) elif interface != "jax" and not has_valid: meas = measurement_with_no_shots(m) m_count += 1 else: - result = [res[m_count] for res in results] + result = results[m_count] if not isinstance(m, CountsMP): # We don't need to cast to arrays when using qml.counts. qml.math.array is not viable # as it assumes all elements of the input are of builtin python types and not belonging # to any particular interface - result = qml.math.stack(result, like=interface) + result = qml.math.array(result, like=interface) + if active_qjit: # pragma: no cover + # `result` contains (bases, counts) need to return (basis, sum(counts)) where `is_valid` + # Any row of `result[0]` contains basis, so we return `result[0][0]` + # We return the sum of counts (`result[1]`) weighting by `is_valid`, which is `0` for invalid samples + if isinstance(m, CountsMP): + normalized_meas.append( + (result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0)) + ) + m_count += 1 + continue + result = qml.math.squeeze(result) meas = gather_non_mcm(m, result, is_valid) m_count += 1 if isinstance(m, SampleMP): @@ -288,6 +304,40 @@ def measurement_with_no_shots(measurement): return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] +def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover + """Process MCM measurements when the Catalyst compiler is active. + + Args: + measurement (MeasurementProcess): measurement + samples (dict): Mid-circuit measurement samples + is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at + each index specifies whether or not the respective sample is valid. + + Returns: + TensorLike: The combined measurement outcome + """ + found, meas = False, None + for k, meas in samples.items(): + if measurement.mv is k.out_classical_tracers[0]: + found = True + break + if not found: + raise LookupError("MCM not found") + meas = qml.math.squeeze(meas) + if isinstance(measurement, (CountsMP, ProbabilityMP)): + interface = qml.math.get_deep_interface(is_valid) + sum_valid = qml.math.sum(is_valid) + count_1 = qml.math.sum(meas * is_valid) + if isinstance(measurement, CountsMP): + return qml.math.array([0, 1], like=interface), qml.math.array( + [sum_valid - count_1, count_1], like=interface + ) + if isinstance(measurement, ProbabilityMP): + counts = qml.math.array([sum_valid - count_1, count_1], like=interface) + return counts / sum_valid + return gather_non_mcm(measurement, meas, is_valid) + + def gather_non_mcm(measurement, samples, is_valid): """Combines, gathers and normalizes several measurements with trivial measurement values. @@ -306,7 +356,8 @@ def gather_non_mcm(measurement, samples, is_valid): tmp.update( dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items()) ) - tmp = Counter({k: v for k, v in tmp.items() if v > 0}) + if not measurement.all_outcomes: + tmp = Counter({k: v for k, v in tmp.items() if v > 0}) return dict(sorted(tmp.items())) if isinstance(measurement, ExpectationMP): return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid) diff --git a/tests/conftest.py b/tests/conftest.py index 1b90d1f20f5..7f2eb29e46f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,7 @@ import contextlib import os import pathlib +import sys import numpy as np import pytest @@ -26,6 +27,8 @@ from pennylane.devices import DefaultGaussian from pennylane.operation import disable_new_opmath_cm, enable_new_opmath_cm +sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) + # defaults TOL = 1e-3 TF_TOL = 2e-2 @@ -206,6 +209,7 @@ def use_legacy_opmath(): yield cm +# pylint: disable=contextmanager-generator-missing-cleanup @pytest.fixture(scope="function") def use_new_opmath(): with enable_new_opmath_cm() as cm: diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index ef1522b166a..6f882cd332d 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for default qubit preprocessing.""" -from functools import reduce -from typing import Iterable, Sequence +from typing import Sequence +import mcm_utils import numpy as np import pytest @@ -31,112 +31,6 @@ def get_device(**kwargs): return qml.device("default.qubit", **kwargs) -def validate_counts(shots, results1, results2, batch_size=None): - """Compares two counts. - - If the results are ``Sequence``s, loop over entries. - - Fails if a key of ``results1`` is not found in ``results2``. - Passes if counts are too low, chosen as ``100``. - Otherwise, fails if counts differ by more than ``20`` plus 20 percent. - """ - if isinstance(shots, Sequence): - assert isinstance(results1, tuple) - assert isinstance(results2, tuple) - assert len(results1) == len(results2) == len(shots) - for s, r1, r2 in zip(shots, results1, results2): - validate_counts(s, r1, r2, batch_size=batch_size) - return - - if batch_size is not None: - assert isinstance(results1, Iterable) - assert isinstance(results2, Iterable) - assert len(results1) == len(results2) == batch_size - for r1, r2 in zip(results1, results2): - validate_counts(shots, r1, r2, batch_size=None) - return - - for key1, val1 in results1.items(): - val2 = results2[key1] - if abs(val1 + val2) > 100: - assert np.allclose(val1, val2, atol=20, rtol=0.2) - - -def validate_samples(shots, results1, results2, batch_size=None): - """Compares two samples. - - If the results are ``Sequence``s, loop over entries. - - Fails if the results do not have the same shape, within ``20`` entries plus 20 percent. - This is to handle cases when post-selection yields variable shapes. - Otherwise, fails if the sums of samples differ by more than ``20`` plus 20 percent. - """ - if isinstance(shots, Sequence): - assert isinstance(results1, tuple) - assert isinstance(results2, tuple) - assert len(results1) == len(results2) == len(shots) - for s, r1, r2 in zip(shots, results1, results2): - validate_samples(s, r1, r2, batch_size=batch_size) - return - - if batch_size is not None: - assert isinstance(results1, Iterable) - assert isinstance(results2, Iterable) - assert len(results1) == len(results2) == batch_size - for r1, r2 in zip(results1, results2): - validate_samples(shots, r1, r2, batch_size=None) - return - - sh1, sh2 = results1.shape[0], results2.shape[0] - assert np.allclose(sh1, sh2, atol=20, rtol=0.2) - assert results1.ndim == results2.ndim - if results2.ndim > 1: - assert results1.shape[1] == results2.shape[1] - np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2) - - -def validate_expval(shots, results1, results2, batch_size=None): - """Compares two expval, probs or var. - - If the results are ``Sequence``s, validate the average of items. - - If ``shots is None``, validate using ``np.allclose``'s default parameters. - Otherwise, fails if the results do not match within ``0.01`` plus 20 percent. - """ - if isinstance(shots, Sequence): - assert isinstance(results1, tuple) - assert isinstance(results2, tuple) - assert len(results1) == len(results2) == len(shots) - results1 = reduce(lambda x, y: x + y, results1) / len(results1) - results2 = reduce(lambda x, y: x + y, results2) / len(results2) - validate_expval(sum(shots), results1, results2, batch_size=batch_size) - return - - if shots is None: - assert np.allclose(results1, results2) - return - - if batch_size is not None: - assert len(results1) == len(results2) == batch_size - for r1, r2 in zip(results1, results2): - validate_expval(shots, r1, r2, batch_size=None) - - assert np.allclose(results1, results2, atol=0.01, rtol=0.2) - - -def validate_measurements(func, shots, results1, results2, batch_size=None): - """Calls the correct validation function based on measurement type.""" - if func is qml.counts: - validate_counts(shots, results1, results2, batch_size=batch_size) - return - - if func is qml.sample: - validate_samples(shots, results1, results2, batch_size=batch_size) - return - - validate_expval(shots, results1, results2, batch_size=batch_size) - - def test_apply_mid_measure(): """Test that apply_mid_measure raises if applied to a batched state.""" with pytest.raises(ValueError, match="MidMeasureMP cannot be applied to batched states."): @@ -263,7 +157,7 @@ def func(x, y, z): func2 = qml.defer_measurements(func) results2 = func2(*params) - validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("postselect", [None, 0, 1]) @@ -297,7 +191,7 @@ def func(x, y, z): for measure_f, r1, r2 in zip( [qml.counts, qml.expval, qml.probs, qml.sample, qml.var], results1, results2 ): - validate_measurements(measure_f, shots, r1, r2) + mcm_utils.validate_measurements(measure_f, shots, r1, r2) @pytest.mark.parametrize( @@ -355,7 +249,7 @@ def func(x): results1 = func1(param) results2 = func2(param) - validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize( @@ -422,7 +316,7 @@ def func(x, y): results1 = func1(*param) results2 = func2(*param) - validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) grad1 = qml.grad(func)(*param) grad2 = qml.grad(func2)(*param) @@ -453,7 +347,7 @@ def func(x, y): results1 = func1(*param) results2 = func2(*param) - validate_measurements(measure_fn, shots, results1, results2, batch_size=2) + mcm_utils.validate_measurements(measure_fn, shots, results1, results2, batch_size=2) if measure_fn is qml.sample and postselect is None: for i in range(2): # batch_size @@ -509,7 +403,7 @@ def func(x, y): results1 = func1(*param) results2 = func2(*param) - validate_measurements(qml.sample, shots, results1, results2, batch_size=None) + mcm_utils.validate_measurements(qml.sample, shots, results1, results2, batch_size=None) evals = obs.eigvals() for eig in evals: @@ -638,4 +532,4 @@ def func(x): results1 = func1(param) results2 = func2(param) - validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) diff --git a/tests/helpers/mcm_utils.py b/tests/helpers/mcm_utils.py new file mode 100644 index 00000000000..fab28c2bfca --- /dev/null +++ b/tests/helpers/mcm_utils.py @@ -0,0 +1,128 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Pytest helper functions are defined in this module. +""" +from functools import reduce +from typing import Iterable, Sequence + +import numpy as np + +import pennylane as qml + + +def validate_counts(shots, results1, results2, batch_size=None): + """Compares two counts. + + If the results are ``Sequence``s, loop over entries. + + Fails if a key of ``results1`` is not found in ``results2``. + Passes if counts are too low, chosen as ``100``. + Otherwise, fails if counts differ by more than ``20`` plus 20 percent. + """ + if isinstance(shots, Sequence): + assert isinstance(results1, tuple) + assert isinstance(results2, tuple) + assert len(results1) == len(results2) == len(shots) + for s, r1, r2 in zip(shots, results1, results2): + validate_counts(s, r1, r2, batch_size=batch_size) + return + + if batch_size is not None: + assert isinstance(results1, Iterable) + assert isinstance(results2, Iterable) + assert len(results1) == len(results2) == batch_size + for r1, r2 in zip(results1, results2): + validate_counts(shots, r1, r2, batch_size=None) + return + + for key1, val1 in results1.items(): + val2 = results2[key1] + if abs(val1 + val2) > 100: + assert np.allclose(val1, val2, atol=20, rtol=0.2) + + +def validate_samples(shots, results1, results2, batch_size=None): + """Compares two samples. + + If the results are ``Sequence``s, loop over entries. + + Fails if the results do not have the same shape, within ``20`` entries plus 20 percent. + This is to handle cases when post-selection yields variable shapes. + Otherwise, fails if the sums of samples differ by more than ``20`` plus 20 percent. + """ + if isinstance(shots, Sequence): + assert isinstance(results1, tuple) + assert isinstance(results2, tuple) + assert len(results1) == len(results2) == len(shots) + for s, r1, r2 in zip(shots, results1, results2): + validate_samples(s, r1, r2, batch_size=batch_size) + return + + if batch_size is not None: + assert isinstance(results1, Iterable) + assert isinstance(results2, Iterable) + assert len(results1) == len(results2) == batch_size + for r1, r2 in zip(results1, results2): + validate_samples(shots, r1, r2, batch_size=None) + return + + sh1, sh2 = results1.shape[0], results2.shape[0] + assert np.allclose(sh1, sh2, atol=20, rtol=0.2) + assert results1.ndim == results2.ndim + if results2.ndim > 1: + assert results1.shape[1] == results2.shape[1] + np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2) + + +def validate_expval(shots, results1, results2, batch_size=None): + """Compares two expval, probs or var. + + If the results are ``Sequence``s, validate the average of items. + + If ``shots is None``, validate using ``np.allclose``'s default parameters. + Otherwise, fails if the results do not match within ``0.01`` plus 20 percent. + """ + if isinstance(shots, Sequence): + assert isinstance(results1, tuple) + assert isinstance(results2, tuple) + assert len(results1) == len(results2) == len(shots) + results1 = reduce(lambda x, y: x + y, results1) / len(results1) + results2 = reduce(lambda x, y: x + y, results2) / len(results2) + validate_expval(sum(shots), results1, results2, batch_size=batch_size) + return + + if shots is None: + assert np.allclose(results1, results2) + return + + if batch_size is not None: + assert len(results1) == len(results2) == batch_size + for r1, r2 in zip(results1, results2): + validate_expval(shots, r1, r2, batch_size=None) + + assert np.allclose(results1, results2, atol=0.01, rtol=0.2) + + +def validate_measurements(func, shots, results1, results2, batch_size=None): + """Calls the correct validation function based on measurement type.""" + if func is qml.counts: + validate_counts(shots, results1, results2, batch_size=batch_size) + return + + if func is qml.sample: + validate_samples(shots, results1, results2, batch_size=batch_size) + return + + validate_expval(shots, results1, results2, batch_size=batch_size) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 370b4776fab..0b407bba814 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -17,11 +17,14 @@ # pylint: disable=import-outside-toplevel from unittest.mock import patch +import mcm_utils +import numpy as np import pytest import pennylane as qml from pennylane import numpy as np from pennylane.compiler.compiler import CompileError +from pennylane.transforms.dynamic_one_shot import fill_in_value catalyst = pytest.importorskip("catalyst") jax = pytest.importorskip("jax") @@ -737,3 +740,74 @@ def circuit(x): assert circuit(0.0) == 0 assert circuit(jnp.pi) == 1 + + +class TestCatalystMCMs: + """Test dynamic_one_shot with Catalyst.""" + + @pytest.mark.xfail(reason="requires simultaneous catalyst pr") + @pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs]) + @pytest.mark.parametrize("meas_obj", [qml.PauliZ(0), [0], "mcm"]) + # pylint: disable=too-many-arguments + def test_dynamic_one_shot_simple(self, measure_f, meas_obj): + """Tests that Catalyst yields the same results as PennyLane's DefaultQubit for a simple + circuit with a mid-circuit measurement.""" + if measure_f in (qml.counts, qml.probs, qml.sample) and ( + not isinstance(meas_obj, list) and not meas_obj == "mcm" + ): + pytest.skip("Can't use observables with counts, probs or sample") + + if measure_f in (qml.var, qml.expval) and (isinstance(meas_obj, list)): + pytest.skip("Can't use wires/mcm lists with var or expval") + + if measure_f == qml.var and (not isinstance(meas_obj, list) and not meas_obj == "mcm"): + pytest.xfail("isa") + shots = 8000 + + dq = qml.device("default.qubit", shots=shots, seed=8237945) + + @qml.defer_measurements + @qml.qnode(dq) + def ref_func(x, y): + qml.RX(x, wires=0) + m0 = qml.measure(0) + qml.cond(m0, qml.RY)(y, wires=1) + + meas_key = "wires" if isinstance(meas_obj, list) else "op" + meas_value = m0 if isinstance(meas_obj, str) else meas_obj + kwargs = {meas_key: meas_value} + if measure_f == qml.counts: + kwargs["all_outcomes"] = True + return measure_f(**kwargs) + + dev = qml.device("lightning.qubit", wires=2, shots=shots) + + @qml.qjit + @catalyst.dynamic_one_shot + @qml.qnode(dev) + def func(x, y): + qml.RX(x, wires=0) + m0 = catalyst.measure(0) + + @catalyst.cond(m0 == 1) + def ansatz(): + qml.RY(y, wires=1) + + ansatz() + + meas_key = "wires" if isinstance(meas_obj, list) else "op" + meas_value = m0 if isinstance(meas_obj, str) else meas_obj + kwargs = {meas_key: meas_value} + return measure_f(**kwargs) + + params = jnp.pi / 4 * jnp.ones(2) + results0 = ref_func(*params) + results1 = func(*params) + if measure_f == qml.counts and isinstance(meas_obj, list): + results1 = { + format(int(state), f"0{len(meas_obj)}b"): count for state, count in zip(*results1) + } + if measure_f == qml.sample: + results0 = results0[results0 != fill_in_value] + results1 = results1[results1 != fill_in_value] + mcm_utils.validate_measurements(measure_f, shots, results1, results0) diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 6c908840359..16ac830cd11 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -45,7 +45,7 @@ def test_parse_native_mid_circuit_measurements_unsupported_meas(measurement): circuit = qml.tape.QuantumScript([qml.RX(1.0, 0)], [measurement]) with pytest.raises(TypeError, match="Native mid-circuit measurement mode does not support"): - parse_native_mid_circuit_measurements(circuit, [circuit], [[]]) + parse_native_mid_circuit_measurements(circuit, [circuit], [np.empty((0,))]) def test_postselection_error_with_wrong_device():