diff --git a/Makefile b/Makefile
index 0a7573ec0b1..5f7186e4549 100644
--- a/Makefile
+++ b/Makefile
@@ -70,10 +70,10 @@ coverage:
.PHONY:format
format:
ifdef check
- isort --py 311 --profile black -l 100 -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check
+ isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check
black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests --check
else
- isort --py 311 --profile black -l 100 -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests
+ isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests
black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests
endif
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 65578cf7017..abef2086ae7 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -79,6 +79,9 @@
Mid-circuit measurements and dynamic circuits
+* The `dynamic_one_shot` transform is made compatible with the Catalyst compiler.
+ [(#5766)](https://github.com/PennyLaneAI/pennylane/pull/5766)
+
* Rationalize MCM tests, removing most end-to-end tests from the native MCM test file,
but keeping one that validates multiple mid-circuit measurements with any allowed return
and interface end-to-end tests.
diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py
index 807be11d2e8..6552504cf03 100644
--- a/pennylane/measurements/counts.py
+++ b/pennylane/measurements/counts.py
@@ -317,7 +317,7 @@ def circuit(x):
# remove nans
mask = qml.math.isnan(samples)
num_wires = shape[-1]
- if np.any(mask):
+ if qml.math.any(mask):
mask = np.logical_not(np.any(mask, axis=tuple(range(1, samples.ndim))))
samples = samples[mask, ...]
diff --git a/pennylane/tape/tape.py b/pennylane/tape/tape.py
index 9e7f3ff9cb5..59f21dae5dd 100644
--- a/pennylane/tape/tape.py
+++ b/pennylane/tape/tape.py
@@ -19,13 +19,7 @@
from threading import RLock
import pennylane as qml
-from pennylane.measurements import (
- CountsMP,
- MeasurementProcess,
- MidMeasureMP,
- ProbabilityMP,
- SampleMP,
-)
+from pennylane.measurements import CountsMP, MeasurementProcess, ProbabilityMP, SampleMP
from pennylane.operation import DecompositionUndefinedError, Operator, StatePrepBase
from pennylane.pytrees import register_pytree
from pennylane.queuing import AnnotatedQueue, QueuingManager, process_queue
@@ -51,7 +45,7 @@ def _validate_computational_basis_sampling(tape):
qubit-wise commutativity relation."""
measurements = tape.measurements
n_meas = len(measurements)
- n_mcms = sum(isinstance(op, MidMeasureMP) for op in tape.operations)
+ n_mcms = sum(qml.transforms.is_mcm(op) for op in tape.operations)
non_comp_basis_sampling_obs = []
comp_basis_sampling_obs = []
comp_basis_indices = []
@@ -68,10 +62,10 @@ def _validate_computational_basis_sampling(tape):
for idx, (cb_obs, global_idx) in enumerate(
zip(comp_basis_sampling_obs, comp_basis_indices)
):
- if cb_obs.wires == empty_wires:
- all_wires = qml.wires.Wires.all_wires([m.wires for m in measurements])
- break
if global_idx < n_meas - n_mcms:
+ if cb_obs.wires == empty_wires:
+ all_wires = qml.wires.Wires.all_wires([m.wires for m in measurements])
+ break
all_wires.append(cb_obs.wires)
if idx == len(comp_basis_sampling_obs) - 1:
all_wires = qml.wires.Wires.all_wires(all_wires)
diff --git a/pennylane/transforms/__init__.py b/pennylane/transforms/__init__.py
index 7df898f2ee5..f69d978040b 100644
--- a/pennylane/transforms/__init__.py
+++ b/pennylane/transforms/__init__.py
@@ -288,7 +288,7 @@ def circuit(x, y):
from .decompositions import clifford_t_decomposition
from .defer_measurements import defer_measurements
-from .dynamic_one_shot import dynamic_one_shot
+from .dynamic_one_shot import dynamic_one_shot, is_mcm
from .sign_expand import sign_expand
from .hamiltonian_expand import hamiltonian_expand, sum_expand
from .split_non_commuting import split_non_commuting
diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py
index 831e30031af..b05c1d583b0 100644
--- a/pennylane/transforms/dynamic_one_shot.py
+++ b/pennylane/transforms/dynamic_one_shot.py
@@ -40,6 +40,12 @@
fill_in_value = np.iinfo(np.int32).min
+def is_mcm(operation):
+ """Returns True if the operation is a mid-circuit measurement and False otherwise."""
+ mcm = isinstance(operation, MidMeasureMP)
+ return mcm or "MidCircuitMeasure" in str(type(operation))
+
+
def null_postprocessing(results):
"""A postprocessing function returned by a transform that only converts the batch of results
into a result for a single ``QuantumTape``.
@@ -118,6 +124,9 @@ def func(x, y):
aux_tapes = [init_auxiliary_tape(t) for t in tapes]
+ def reshape_data(array):
+ return qml.math.squeeze(qml.math.vstack(array))
+
def processing_fn(results, has_partitioned_shots=None, batched_results=None):
if batched_results is None and batch_size is not None:
# If broadcasting, recursively process the results for each batch. For each batch
@@ -141,6 +150,14 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None):
return tuple(final_results)
if not tape.shots.has_partitioned_shots:
results = results[0]
+
+ is_scalar = not isinstance(results[0], Sequence)
+ if is_scalar:
+ results = [reshape_data(tuple(results))]
+ else:
+ results = [
+ reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0])
+ ]
return parse_native_mid_circuit_measurements(tape, aux_tapes, results)
return aux_tapes, processing_fn
@@ -170,12 +187,6 @@ def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs):
return self.default_qnode_transform(qnode, targs, tkwargs)
-def is_mcm(operation):
- """Returns True if the operation is a mid-circuit measurement and False otherwise."""
- mcm = isinstance(operation, MidMeasureMP)
- return mcm or "MidCircuitMeasure" in str(type(operation))
-
-
def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
"""Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations.
@@ -195,10 +206,11 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
new_measurements.append(SampleMP(obs=m.obs))
else:
new_measurements.append(m)
- for op in circuit:
- if is_mcm(op):
+ for op in circuit.operations:
+ if "MidCircuitMeasure" in str(type(op)): # pragma: no cover
+ new_measurements.append(qml.sample(op.out_classical_tracers[0]))
+ elif isinstance(op, MidMeasureMP):
new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res)))
-
return qml.tape.QuantumScript(
circuit.operations,
new_measurements,
@@ -207,14 +219,15 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
)
+# pylint: disable=too-many-branches,too-many-statements
def parse_native_mid_circuit_measurements(
circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results: TensorLike
):
"""Combines, gathers and normalizes the results of native mid-circuit measurement runs.
Args:
- circuit (QuantumTape): Initial ``QuantumScript``
- aux_tapes (List[QuantumTape]): List of auxilary ``QuantumScript`` objects
+ circuit (QuantumTape): The original ``QuantumScript``
+ aux_tapes (List[QuantumTape]): List of auxiliary ``QuantumScript`` objects
results (TensorLike): Array of measurement results
Returns:
@@ -230,21 +243,12 @@ def measurement_with_no_shots(measurement):
interface = qml.math.get_deep_interface(circuit.data)
interface = "numpy" if interface == "builtins" else interface
+ active_qjit = qml.compiler.active()
all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
n_mcms = len(all_mcms)
- post_process_tape = qml.tape.QuantumScript(
- aux_tapes[0].operations,
- aux_tapes[0].measurements[0:-n_mcms],
- shots=aux_tapes[0].shots,
- trainable_params=aux_tapes[0].trainable_params,
- )
- single_measurement = (
- len(post_process_tape.measurements) == 0 and len(aux_tapes[0].measurements) == 1
- )
- mcm_samples = qml.math.array(
- [[res] if single_measurement else res[-n_mcms::] for res in results], like=interface
- )
+ mcm_samples = qml.math.hstack(tuple(res.reshape((-1, 1)) for res in results[-n_mcms:]))
+ mcm_samples = qml.math.array(mcm_samples, like=interface)
# Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
has_postselect = qml.math.array(
[[int(op.postselect is not None) for op in all_mcms]], like=interface
@@ -257,7 +261,6 @@ def measurement_with_no_shots(measurement):
mid_meas = [op for op in circuit.operations if is_mcm(op)]
mcm_samples = [mcm_samples[:, i : i + 1] for i in range(n_mcms)]
mcm_samples = dict((k, v) for k, v in zip(mid_meas, mcm_samples))
-
normalized_meas = []
m_count = 0
for m in circuit.measurements:
@@ -267,18 +270,31 @@ def measurement_with_no_shots(measurement):
)
if interface != "jax" and m.mv and not has_valid:
meas = measurement_with_no_shots(m)
+ elif m.mv and active_qjit:
+ meas = gather_mcm_qjit(m, mcm_samples, is_valid) # pragma: no cover
elif m.mv:
meas = gather_mcm(m, mcm_samples, is_valid)
elif interface != "jax" and not has_valid:
meas = measurement_with_no_shots(m)
m_count += 1
else:
- result = [res[m_count] for res in results]
+ result = results[m_count]
if not isinstance(m, CountsMP):
# We don't need to cast to arrays when using qml.counts. qml.math.array is not viable
# as it assumes all elements of the input are of builtin python types and not belonging
# to any particular interface
- result = qml.math.stack(result, like=interface)
+ result = qml.math.array(result, like=interface)
+ if active_qjit: # pragma: no cover
+ # `result` contains (bases, counts) need to return (basis, sum(counts)) where `is_valid`
+ # Any row of `result[0]` contains basis, so we return `result[0][0]`
+ # We return the sum of counts (`result[1]`) weighting by `is_valid`, which is `0` for invalid samples
+ if isinstance(m, CountsMP):
+ normalized_meas.append(
+ (result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0))
+ )
+ m_count += 1
+ continue
+ result = qml.math.squeeze(result)
meas = gather_non_mcm(m, result, is_valid)
m_count += 1
if isinstance(m, SampleMP):
@@ -288,6 +304,40 @@ def measurement_with_no_shots(measurement):
return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0]
+def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover
+ """Process MCM measurements when the Catalyst compiler is active.
+
+ Args:
+ measurement (MeasurementProcess): measurement
+ samples (dict): Mid-circuit measurement samples
+ is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at
+ each index specifies whether or not the respective sample is valid.
+
+ Returns:
+ TensorLike: The combined measurement outcome
+ """
+ found, meas = False, None
+ for k, meas in samples.items():
+ if measurement.mv is k.out_classical_tracers[0]:
+ found = True
+ break
+ if not found:
+ raise LookupError("MCM not found")
+ meas = qml.math.squeeze(meas)
+ if isinstance(measurement, (CountsMP, ProbabilityMP)):
+ interface = qml.math.get_deep_interface(is_valid)
+ sum_valid = qml.math.sum(is_valid)
+ count_1 = qml.math.sum(meas * is_valid)
+ if isinstance(measurement, CountsMP):
+ return qml.math.array([0, 1], like=interface), qml.math.array(
+ [sum_valid - count_1, count_1], like=interface
+ )
+ if isinstance(measurement, ProbabilityMP):
+ counts = qml.math.array([sum_valid - count_1, count_1], like=interface)
+ return counts / sum_valid
+ return gather_non_mcm(measurement, meas, is_valid)
+
+
def gather_non_mcm(measurement, samples, is_valid):
"""Combines, gathers and normalizes several measurements with trivial measurement values.
@@ -306,7 +356,8 @@ def gather_non_mcm(measurement, samples, is_valid):
tmp.update(
dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items())
)
- tmp = Counter({k: v for k, v in tmp.items() if v > 0})
+ if not measurement.all_outcomes:
+ tmp = Counter({k: v for k, v in tmp.items() if v > 0})
return dict(sorted(tmp.items()))
if isinstance(measurement, ExpectationMP):
return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
diff --git a/tests/conftest.py b/tests/conftest.py
index 1b90d1f20f5..7f2eb29e46f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,6 +18,7 @@
import contextlib
import os
import pathlib
+import sys
import numpy as np
import pytest
@@ -26,6 +27,8 @@
from pennylane.devices import DefaultGaussian
from pennylane.operation import disable_new_opmath_cm, enable_new_opmath_cm
+sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
+
# defaults
TOL = 1e-3
TF_TOL = 2e-2
@@ -206,6 +209,7 @@ def use_legacy_opmath():
yield cm
+# pylint: disable=contextmanager-generator-missing-cleanup
@pytest.fixture(scope="function")
def use_new_opmath():
with enable_new_opmath_cm() as cm:
diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
index ef1522b166a..6f882cd332d 100644
--- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py
+++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for default qubit preprocessing."""
-from functools import reduce
-from typing import Iterable, Sequence
+from typing import Sequence
+import mcm_utils
import numpy as np
import pytest
@@ -31,112 +31,6 @@ def get_device(**kwargs):
return qml.device("default.qubit", **kwargs)
-def validate_counts(shots, results1, results2, batch_size=None):
- """Compares two counts.
-
- If the results are ``Sequence``s, loop over entries.
-
- Fails if a key of ``results1`` is not found in ``results2``.
- Passes if counts are too low, chosen as ``100``.
- Otherwise, fails if counts differ by more than ``20`` plus 20 percent.
- """
- if isinstance(shots, Sequence):
- assert isinstance(results1, tuple)
- assert isinstance(results2, tuple)
- assert len(results1) == len(results2) == len(shots)
- for s, r1, r2 in zip(shots, results1, results2):
- validate_counts(s, r1, r2, batch_size=batch_size)
- return
-
- if batch_size is not None:
- assert isinstance(results1, Iterable)
- assert isinstance(results2, Iterable)
- assert len(results1) == len(results2) == batch_size
- for r1, r2 in zip(results1, results2):
- validate_counts(shots, r1, r2, batch_size=None)
- return
-
- for key1, val1 in results1.items():
- val2 = results2[key1]
- if abs(val1 + val2) > 100:
- assert np.allclose(val1, val2, atol=20, rtol=0.2)
-
-
-def validate_samples(shots, results1, results2, batch_size=None):
- """Compares two samples.
-
- If the results are ``Sequence``s, loop over entries.
-
- Fails if the results do not have the same shape, within ``20`` entries plus 20 percent.
- This is to handle cases when post-selection yields variable shapes.
- Otherwise, fails if the sums of samples differ by more than ``20`` plus 20 percent.
- """
- if isinstance(shots, Sequence):
- assert isinstance(results1, tuple)
- assert isinstance(results2, tuple)
- assert len(results1) == len(results2) == len(shots)
- for s, r1, r2 in zip(shots, results1, results2):
- validate_samples(s, r1, r2, batch_size=batch_size)
- return
-
- if batch_size is not None:
- assert isinstance(results1, Iterable)
- assert isinstance(results2, Iterable)
- assert len(results1) == len(results2) == batch_size
- for r1, r2 in zip(results1, results2):
- validate_samples(shots, r1, r2, batch_size=None)
- return
-
- sh1, sh2 = results1.shape[0], results2.shape[0]
- assert np.allclose(sh1, sh2, atol=20, rtol=0.2)
- assert results1.ndim == results2.ndim
- if results2.ndim > 1:
- assert results1.shape[1] == results2.shape[1]
- np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2)
-
-
-def validate_expval(shots, results1, results2, batch_size=None):
- """Compares two expval, probs or var.
-
- If the results are ``Sequence``s, validate the average of items.
-
- If ``shots is None``, validate using ``np.allclose``'s default parameters.
- Otherwise, fails if the results do not match within ``0.01`` plus 20 percent.
- """
- if isinstance(shots, Sequence):
- assert isinstance(results1, tuple)
- assert isinstance(results2, tuple)
- assert len(results1) == len(results2) == len(shots)
- results1 = reduce(lambda x, y: x + y, results1) / len(results1)
- results2 = reduce(lambda x, y: x + y, results2) / len(results2)
- validate_expval(sum(shots), results1, results2, batch_size=batch_size)
- return
-
- if shots is None:
- assert np.allclose(results1, results2)
- return
-
- if batch_size is not None:
- assert len(results1) == len(results2) == batch_size
- for r1, r2 in zip(results1, results2):
- validate_expval(shots, r1, r2, batch_size=None)
-
- assert np.allclose(results1, results2, atol=0.01, rtol=0.2)
-
-
-def validate_measurements(func, shots, results1, results2, batch_size=None):
- """Calls the correct validation function based on measurement type."""
- if func is qml.counts:
- validate_counts(shots, results1, results2, batch_size=batch_size)
- return
-
- if func is qml.sample:
- validate_samples(shots, results1, results2, batch_size=batch_size)
- return
-
- validate_expval(shots, results1, results2, batch_size=batch_size)
-
-
def test_apply_mid_measure():
"""Test that apply_mid_measure raises if applied to a batched state."""
with pytest.raises(ValueError, match="MidMeasureMP cannot be applied to batched states."):
@@ -263,7 +157,7 @@ def func(x, y, z):
func2 = qml.defer_measurements(func)
results2 = func2(*params)
- validate_measurements(measure_f, shots, results1, results2)
+ mcm_utils.validate_measurements(measure_f, shots, results1, results2)
@pytest.mark.parametrize("postselect", [None, 0, 1])
@@ -297,7 +191,7 @@ def func(x, y, z):
for measure_f, r1, r2 in zip(
[qml.counts, qml.expval, qml.probs, qml.sample, qml.var], results1, results2
):
- validate_measurements(measure_f, shots, r1, r2)
+ mcm_utils.validate_measurements(measure_f, shots, r1, r2)
@pytest.mark.parametrize(
@@ -355,7 +249,7 @@ def func(x):
results1 = func1(param)
results2 = func2(param)
- validate_measurements(measure_f, shots, results1, results2)
+ mcm_utils.validate_measurements(measure_f, shots, results1, results2)
@pytest.mark.parametrize(
@@ -422,7 +316,7 @@ def func(x, y):
results1 = func1(*param)
results2 = func2(*param)
- validate_measurements(measure_f, shots, results1, results2)
+ mcm_utils.validate_measurements(measure_f, shots, results1, results2)
grad1 = qml.grad(func)(*param)
grad2 = qml.grad(func2)(*param)
@@ -453,7 +347,7 @@ def func(x, y):
results1 = func1(*param)
results2 = func2(*param)
- validate_measurements(measure_fn, shots, results1, results2, batch_size=2)
+ mcm_utils.validate_measurements(measure_fn, shots, results1, results2, batch_size=2)
if measure_fn is qml.sample and postselect is None:
for i in range(2): # batch_size
@@ -509,7 +403,7 @@ def func(x, y):
results1 = func1(*param)
results2 = func2(*param)
- validate_measurements(qml.sample, shots, results1, results2, batch_size=None)
+ mcm_utils.validate_measurements(qml.sample, shots, results1, results2, batch_size=None)
evals = obs.eigvals()
for eig in evals:
@@ -638,4 +532,4 @@ def func(x):
results1 = func1(param)
results2 = func2(param)
- validate_measurements(measure_f, shots, results1, results2)
+ mcm_utils.validate_measurements(measure_f, shots, results1, results2)
diff --git a/tests/helpers/mcm_utils.py b/tests/helpers/mcm_utils.py
new file mode 100644
index 00000000000..fab28c2bfca
--- /dev/null
+++ b/tests/helpers/mcm_utils.py
@@ -0,0 +1,128 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Pytest helper functions are defined in this module.
+"""
+from functools import reduce
+from typing import Iterable, Sequence
+
+import numpy as np
+
+import pennylane as qml
+
+
+def validate_counts(shots, results1, results2, batch_size=None):
+ """Compares two counts.
+
+ If the results are ``Sequence``s, loop over entries.
+
+ Fails if a key of ``results1`` is not found in ``results2``.
+ Passes if counts are too low, chosen as ``100``.
+ Otherwise, fails if counts differ by more than ``20`` plus 20 percent.
+ """
+ if isinstance(shots, Sequence):
+ assert isinstance(results1, tuple)
+ assert isinstance(results2, tuple)
+ assert len(results1) == len(results2) == len(shots)
+ for s, r1, r2 in zip(shots, results1, results2):
+ validate_counts(s, r1, r2, batch_size=batch_size)
+ return
+
+ if batch_size is not None:
+ assert isinstance(results1, Iterable)
+ assert isinstance(results2, Iterable)
+ assert len(results1) == len(results2) == batch_size
+ for r1, r2 in zip(results1, results2):
+ validate_counts(shots, r1, r2, batch_size=None)
+ return
+
+ for key1, val1 in results1.items():
+ val2 = results2[key1]
+ if abs(val1 + val2) > 100:
+ assert np.allclose(val1, val2, atol=20, rtol=0.2)
+
+
+def validate_samples(shots, results1, results2, batch_size=None):
+ """Compares two samples.
+
+ If the results are ``Sequence``s, loop over entries.
+
+ Fails if the results do not have the same shape, within ``20`` entries plus 20 percent.
+ This is to handle cases when post-selection yields variable shapes.
+ Otherwise, fails if the sums of samples differ by more than ``20`` plus 20 percent.
+ """
+ if isinstance(shots, Sequence):
+ assert isinstance(results1, tuple)
+ assert isinstance(results2, tuple)
+ assert len(results1) == len(results2) == len(shots)
+ for s, r1, r2 in zip(shots, results1, results2):
+ validate_samples(s, r1, r2, batch_size=batch_size)
+ return
+
+ if batch_size is not None:
+ assert isinstance(results1, Iterable)
+ assert isinstance(results2, Iterable)
+ assert len(results1) == len(results2) == batch_size
+ for r1, r2 in zip(results1, results2):
+ validate_samples(shots, r1, r2, batch_size=None)
+ return
+
+ sh1, sh2 = results1.shape[0], results2.shape[0]
+ assert np.allclose(sh1, sh2, atol=20, rtol=0.2)
+ assert results1.ndim == results2.ndim
+ if results2.ndim > 1:
+ assert results1.shape[1] == results2.shape[1]
+ np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2)
+
+
+def validate_expval(shots, results1, results2, batch_size=None):
+ """Compares two expval, probs or var.
+
+ If the results are ``Sequence``s, validate the average of items.
+
+ If ``shots is None``, validate using ``np.allclose``'s default parameters.
+ Otherwise, fails if the results do not match within ``0.01`` plus 20 percent.
+ """
+ if isinstance(shots, Sequence):
+ assert isinstance(results1, tuple)
+ assert isinstance(results2, tuple)
+ assert len(results1) == len(results2) == len(shots)
+ results1 = reduce(lambda x, y: x + y, results1) / len(results1)
+ results2 = reduce(lambda x, y: x + y, results2) / len(results2)
+ validate_expval(sum(shots), results1, results2, batch_size=batch_size)
+ return
+
+ if shots is None:
+ assert np.allclose(results1, results2)
+ return
+
+ if batch_size is not None:
+ assert len(results1) == len(results2) == batch_size
+ for r1, r2 in zip(results1, results2):
+ validate_expval(shots, r1, r2, batch_size=None)
+
+ assert np.allclose(results1, results2, atol=0.01, rtol=0.2)
+
+
+def validate_measurements(func, shots, results1, results2, batch_size=None):
+ """Calls the correct validation function based on measurement type."""
+ if func is qml.counts:
+ validate_counts(shots, results1, results2, batch_size=batch_size)
+ return
+
+ if func is qml.sample:
+ validate_samples(shots, results1, results2, batch_size=batch_size)
+ return
+
+ validate_expval(shots, results1, results2, batch_size=batch_size)
diff --git a/tests/test_compiler.py b/tests/test_compiler.py
index 370b4776fab..0b407bba814 100644
--- a/tests/test_compiler.py
+++ b/tests/test_compiler.py
@@ -17,11 +17,14 @@
# pylint: disable=import-outside-toplevel
from unittest.mock import patch
+import mcm_utils
+import numpy as np
import pytest
import pennylane as qml
from pennylane import numpy as np
from pennylane.compiler.compiler import CompileError
+from pennylane.transforms.dynamic_one_shot import fill_in_value
catalyst = pytest.importorskip("catalyst")
jax = pytest.importorskip("jax")
@@ -737,3 +740,74 @@ def circuit(x):
assert circuit(0.0) == 0
assert circuit(jnp.pi) == 1
+
+
+class TestCatalystMCMs:
+ """Test dynamic_one_shot with Catalyst."""
+
+ @pytest.mark.xfail(reason="requires simultaneous catalyst pr")
+ @pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs])
+ @pytest.mark.parametrize("meas_obj", [qml.PauliZ(0), [0], "mcm"])
+ # pylint: disable=too-many-arguments
+ def test_dynamic_one_shot_simple(self, measure_f, meas_obj):
+ """Tests that Catalyst yields the same results as PennyLane's DefaultQubit for a simple
+ circuit with a mid-circuit measurement."""
+ if measure_f in (qml.counts, qml.probs, qml.sample) and (
+ not isinstance(meas_obj, list) and not meas_obj == "mcm"
+ ):
+ pytest.skip("Can't use observables with counts, probs or sample")
+
+ if measure_f in (qml.var, qml.expval) and (isinstance(meas_obj, list)):
+ pytest.skip("Can't use wires/mcm lists with var or expval")
+
+ if measure_f == qml.var and (not isinstance(meas_obj, list) and not meas_obj == "mcm"):
+ pytest.xfail("isa")
+ shots = 8000
+
+ dq = qml.device("default.qubit", shots=shots, seed=8237945)
+
+ @qml.defer_measurements
+ @qml.qnode(dq)
+ def ref_func(x, y):
+ qml.RX(x, wires=0)
+ m0 = qml.measure(0)
+ qml.cond(m0, qml.RY)(y, wires=1)
+
+ meas_key = "wires" if isinstance(meas_obj, list) else "op"
+ meas_value = m0 if isinstance(meas_obj, str) else meas_obj
+ kwargs = {meas_key: meas_value}
+ if measure_f == qml.counts:
+ kwargs["all_outcomes"] = True
+ return measure_f(**kwargs)
+
+ dev = qml.device("lightning.qubit", wires=2, shots=shots)
+
+ @qml.qjit
+ @catalyst.dynamic_one_shot
+ @qml.qnode(dev)
+ def func(x, y):
+ qml.RX(x, wires=0)
+ m0 = catalyst.measure(0)
+
+ @catalyst.cond(m0 == 1)
+ def ansatz():
+ qml.RY(y, wires=1)
+
+ ansatz()
+
+ meas_key = "wires" if isinstance(meas_obj, list) else "op"
+ meas_value = m0 if isinstance(meas_obj, str) else meas_obj
+ kwargs = {meas_key: meas_value}
+ return measure_f(**kwargs)
+
+ params = jnp.pi / 4 * jnp.ones(2)
+ results0 = ref_func(*params)
+ results1 = func(*params)
+ if measure_f == qml.counts and isinstance(meas_obj, list):
+ results1 = {
+ format(int(state), f"0{len(meas_obj)}b"): count for state, count in zip(*results1)
+ }
+ if measure_f == qml.sample:
+ results0 = results0[results0 != fill_in_value]
+ results1 = results1[results1 != fill_in_value]
+ mcm_utils.validate_measurements(measure_f, shots, results1, results0)
diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py
index 6c908840359..16ac830cd11 100644
--- a/tests/transforms/test_dynamic_one_shot.py
+++ b/tests/transforms/test_dynamic_one_shot.py
@@ -45,7 +45,7 @@
def test_parse_native_mid_circuit_measurements_unsupported_meas(measurement):
circuit = qml.tape.QuantumScript([qml.RX(1.0, 0)], [measurement])
with pytest.raises(TypeError, match="Native mid-circuit measurement mode does not support"):
- parse_native_mid_circuit_measurements(circuit, [circuit], [[]])
+ parse_native_mid_circuit_measurements(circuit, [circuit], [np.empty((0,))])
def test_postselection_error_with_wrong_device():