Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dynamic_one_shot tensorflow support + expanded testing #5973

Merged
merged 63 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
bc81271
Added test skeleton
mudit2812 Jun 17, 2024
c7eeeb4
Changelog entry
mudit2812 Jun 17, 2024
33782f6
Update tf interface check
mudit2812 Jun 17, 2024
11aa3c5
Fix casting rules for tf
mudit2812 Jun 18, 2024
70fa755
Fix more casting for tf support
mudit2812 Jun 18, 2024
30b154d
Add dev comments
mudit2812 Jun 18, 2024
3683853
Fix jax tests
mudit2812 Jun 18, 2024
ad2fcd6
Merge branch 'master' into dos-grad-tests
mudit2812 Jun 18, 2024
f1b2e96
Convert integration tests to unit tests
mudit2812 Jun 19, 2024
9742ec9
Merge branch 'master' into dos-grad-tests
mudit2812 Jun 19, 2024
3b1bf4e
[skip ci] Update tf casting validation
mudit2812 Jun 20, 2024
35235b5
Add finite-diff test
mudit2812 Jun 20, 2024
68aeee3
Update documentation for the new Algo debugging feature (#5894)
Jaybsoni Jun 26, 2024
689f063
Raise DeprecationWarning only when the qasm code contains measurement…
astralcai Jun 28, 2024
1f76752
Specifying `wire_order` for initial MPS in `default.tensor` (#5892)
PietropaoloFrisoni Jun 29, 2024
5c7ae2b
0.37 Release notes - v1 (#5918)
trbromley Jul 2, 2024
53d44d5
Incorporate `level` keyword in `draw` and `draw_mpl` (#5855)
Shiro-Raven Jun 19, 2024
8591b52
`add_noise` transform for adding noise models (#5718)
obliviateandsurrender Jun 19, 2024
40bf355
Fix `hadamard_grad` with wires-broadcasted measurements (#5860)
dwierichs Jun 20, 2024
a1bb41f
Update stable dependency files (#5809)
github-actions[bot] Jun 20, 2024
2511c73
Adding documentation for `qml.breakpoint()` and `qml.PLDB` (#5789)
Jaybsoni Jun 20, 2024
053341a
Minor update to dataset docs to lead users to quickstart and list_dat…
DSGuala Jun 20, 2024
4b6683c
Add quickstart page on mid-circuit measurements (#5870)
dwierichs Jun 20, 2024
463e8f5
Add deprecation warning to from_qasm (#5882)
astralcai Jun 20, 2024
656fec2
Support qubit operator in ``from_openfermion`` (#5881)
soranjh Jun 21, 2024
12cdd88
Add Dataset Attribute type for Pytrees (#5732)
brownj85 Jun 21, 2024
5305961
Minor program capture fixes (#5889)
albi3ro Jun 21, 2024
19cd65b
Added jax support to private function `_qsp_to_qsvt()` which handles …
Jaybsoni Jun 21, 2024
8fbde6d
Fix wrong PaulitRot decomposition with identity as pauli word (#5875)
EmilianoG-byte Jun 21, 2024
5485091
Small `dynamic_one_shot` change to account for Catalyst updates (#5888)
mudit2812 Jun 21, 2024
31c6d60
Preparation for release candidate 0.37.0 (#5898)
astralcai Jun 24, 2024
bbe33ff
Update documentation for the new Algo debugging feature (#5894)
Jaybsoni Jun 26, 2024
6f46449
Raise DeprecationWarning only when the qasm code contains measurement…
astralcai Jun 28, 2024
fb4856a
Specifying `wire_order` for initial MPS in `default.tensor` (#5892)
PietropaoloFrisoni Jun 29, 2024
386397d
0.37 Release notes - v1 (#5918)
trbromley Jul 2, 2024
4d29830
Merge branch 'v0.37.0-rc0' into dos-grad-tests
mudit2812 Jul 2, 2024
b87710d
Change interface validation strategy
mudit2812 Jul 9, 2024
2f9a2f3
Merge branch 'master' into dos-grad-tests
mudit2812 Jul 9, 2024
12a9afc
Revert debugger merge conflict
mudit2812 Jul 9, 2024
8543903
Changelog
mudit2812 Jul 9, 2024
b0073f4
Adding back interface arg for lightning compat
mudit2812 Jul 9, 2024
a59b56b
Merge branch 'master' into dos-grad-tests
mudit2812 Jul 9, 2024
bb3c20a
Fix MCMConfig copying mechanism; TODO add tests
mudit2812 Jul 9, 2024
59e6891
Added mcm_config as a qnode attr instead of a key to QNode.execute_kw…
mudit2812 Jul 10, 2024
ba55f7f
Merge branch 'master' into dos-grad-tests
mudit2812 Jul 12, 2024
35d273f
Add private postselect_mode
mudit2812 Jul 15, 2024
ea04b8a
Merge branch 'master' into dos-grad-tests
mudit2812 Jul 15, 2024
3d1f1a0
Add kwargs to mcm transform
mudit2812 Jul 15, 2024
dbe2ccf
Revert qnode location of mcm_config
mudit2812 Jul 19, 2024
9498af0
Merge branch 'master' into dos-grad-tests
mudit2812 Jul 19, 2024
167fe2b
Fix failing tests
mudit2812 Jul 19, 2024
84dd4ff
Tidy up execute
mudit2812 Jul 19, 2024
08810f9
Attempt number 3
mudit2812 Jul 19, 2024
100fcf3
Final tidying up
mudit2812 Jul 19, 2024
5aeb93f
Linting
mudit2812 Jul 19, 2024
bd6ed68
Update execute_kwargs logic
mudit2812 Jul 22, 2024
792f96e
Update execute_kwargs logic again
mudit2812 Jul 22, 2024
b491d23
Addressing PR review
mudit2812 Jul 22, 2024
4b33602
Merge branch 'master' into dos-grad-tests
mudit2812 Jul 22, 2024
877c2b6
Apply suggestions from code review
mudit2812 Jul 24, 2024
dad9adb
Merge branch 'master' into dos-grad-tests
mudit2812 Jul 24, 2024
d45f84d
Merge branch 'master' into dos-grad-tests
mudit2812 Jul 26, 2024
0afaa2c
Merge branch 'master' into dos-grad-tests
mudit2812 Jul 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/releases/changelog-0.37.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -929,4 +929,4 @@ Kenya Sakka,
Jay Soni,
Kazuki Tsuoka,
Haochen Paul Wang,
David Wierichs.
David Wierichs.
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
12 changes: 9 additions & 3 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@

<h3>New features since last release</h3>

* A new method `process_density_matrix` has been added to the `ProbabilityMP` and `DensityMatrixMP`
classes, allowing for more efficient handling of quantum density matrices, particularly with batch
processing support. This method simplifies the calculation of probabilities from quantum states
represented as density matrices.
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
[(#5830)](https://github.com/PennyLaneAI/pennylane/pull/5830)

* Resolved the bug in `qml.ThermalRelaxationError` where there was a typo from `tq` to `tg`.
[(#5988)](https://github.com/PennyLaneAI/pennylane/issues/5988)

* A new method `process_density_matrix` has been added to the `ProbabilityMP` and `DensityMatrixMP` classes, allowing for more efficient handling of quantum density matrices, particularly with batch processing support. This method simplifies the calculation of probabilities from quantum states represented as density matrices.
[(#5830)](https://github.com/PennyLaneAI/pennylane/pull/5830)

* The `qml.PrepSelPrep` template is added. The template implements a block-encoding of a linear
combination of unitaries.
[(#5756)](https://github.com/PennyLaneAI/pennylane/pull/5756)
Expand Down Expand Up @@ -57,6 +60,9 @@
* `QuantumScript.hash` is now cached, leading to performance improvements.
[(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919)

* `qml.dynamic_one_shot` now supports circuits using the `"tensorflow"` interface.
[(#5973)](https://github.com/PennyLaneAI/pennylane/pull/5973)

* The representation for `Wires` has now changed to be more copy-paste friendly.
[(#5958)](https://github.com/PennyLaneAI/pennylane/pull/5958)

Expand Down
5 changes: 1 addition & 4 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,7 @@ def preprocess(

transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
transform_program.add_transform(
mid_circuit_measurements,
device=self,
mcm_config=config.mcm_config,
interface=config.interface,
mid_circuit_measurements, device=self, mcm_config=config.mcm_config
)
transform_program.add_transform(
decompose,
Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __post_init__(self):
None,
):
raise ValueError(f"Invalid mid-circuit measurements method '{self.mcm_method}'.")
if self.postselect_mode not in ("hw-like", "fill-shots", None):
if self.postselect_mode not in ("hw-like", "fill-shots", "pad-invalid-samples", None):
raise ValueError(f"Invalid postselection mode '{self.postselect_mode}'.")


Expand Down
5 changes: 2 additions & 3 deletions pennylane/devices/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,21 @@ def mid_circuit_measurements(
tape: qml.tape.QuantumTape,
device,
mcm_config=MCMConfig(),
interface=None,
**kwargs, # pylint: disable=unused-argument
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
) -> tuple[QuantumTapeBatch, PostprocessingFn]:
"""Provide the transform to handle mid-circuit measurements.

If the tape or device uses finite-shot, use the native implementation (i.e. no transform),
and use the ``qml.defer_measurements`` transform otherwise.
"""

if isinstance(mcm_config, dict):
mcm_config = MCMConfig(**mcm_config)
mcm_method = mcm_config.mcm_method
if mcm_method is None:
mcm_method = "one-shot" if tape.shots else "deferred"

if mcm_method == "one-shot":
return qml.dynamic_one_shot(tape, interface=interface)
return qml.dynamic_one_shot(tape, postselect_mode=mcm_config.postselect_mode)
if mcm_method == "tree-traversal":
return (tape,), null_postprocessing
return qml.defer_measurements(tape, device=device)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def apply_mid_measure(
axis = wire.toarray()[0]
slices = [slice(None)] * qml.math.ndim(state)
slices[axis] = 0
prob0 = qml.math.norm(state[tuple(slices)]) ** 2
prob0 = qml.math.real(qml.math.norm(state[tuple(slices)])) ** 2

if prng_key is not None:
# pylint: disable=import-outside-toplevel
Expand Down
1 change: 1 addition & 0 deletions pennylane/math/single_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def _take_autograd(tensor, indices, axis=None):
ar.autoray._SUBMODULE_ALIASES["tensorflow", "arctan"] = "tensorflow.math"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "arctan2"] = "tensorflow.math"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "mod"] = "tensorflow.math"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "logical_and"] = "tensorflow.math"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "kron"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "moveaxis"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "sinc"] = "tensorflow.experimental.numpy"
Expand Down
2 changes: 1 addition & 1 deletion pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def __mul__(self, other):
return self._transform_bin_op(lambda a, b: a * b, other)

def __rmul__(self, other):
return self._apply(lambda v: other * v)
return self._apply(lambda v: other * qml.math.cast_like(v, other))
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved

def __truediv__(self, other):
return self._transform_bin_op(lambda a, b: a / b, other)
Expand Down
96 changes: 67 additions & 29 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def func(x, y):

aux_tapes = [init_auxiliary_tape(t) for t in tapes]

interface = kwargs.get("interface", None)
postselect_mode = kwargs.get("postselect_mode", None)

def reshape_data(array):
return qml.math.squeeze(qml.math.vstack(array))
Expand Down Expand Up @@ -161,7 +161,9 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None):
results = [
reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0])
]
return parse_native_mid_circuit_measurements(tape, aux_tapes, results, interface=interface)
return parse_native_mid_circuit_measurements(
tape, aux_tapes, results, postselect_mode=postselect_mode
)

return aux_tapes, processing_fn

Expand Down Expand Up @@ -227,7 +229,7 @@ def parse_native_mid_circuit_measurements(
circuit: qml.tape.QuantumScript,
aux_tapes: qml.tape.QuantumScript,
results: TensorLike,
interface=None,
postselect_mode=None,
):
"""Combines, gathers and normalizes the results of native mid-circuit measurement runs.

Expand All @@ -247,20 +249,27 @@ def measurement_with_no_shots(measurement):
else np.nan
)

interface = interface or qml.math.get_deep_interface(circuit.data)
interface = qml.math.get_deep_interface(results)
interface = "numpy" if interface == "builtins" else interface
interface = "tensorflow" if interface == "tf" else interface
active_qjit = qml.compiler.active()

all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
n_mcms = len(all_mcms)
mcm_samples = qml.math.hstack(tuple(res.reshape((-1, 1)) for res in results[-n_mcms:]))
mcm_samples = qml.math.hstack(
tuple(qml.math.reshape(res, (-1, 1)) for res in results[-n_mcms:])
)
mcm_samples = qml.math.array(mcm_samples, like=interface)
# Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
has_postselect = qml.math.array(
[[int(op.postselect is not None) for op in all_mcms]], like=interface
[[op.postselect is not None for op in all_mcms]],
like=interface,
dtype=mcm_samples.dtype,
)
postselect = qml.math.array(
[[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface
[[0 if op.postselect is None else op.postselect for op in all_mcms]],
like=interface,
dtype=mcm_samples.dtype,
)
is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
has_valid = qml.math.any(is_valid)
Expand All @@ -277,9 +286,11 @@ def measurement_with_no_shots(measurement):
if interface != "jax" and m.mv and not has_valid:
meas = measurement_with_no_shots(m)
elif m.mv and active_qjit:
meas = gather_mcm_qjit(m, mcm_samples, is_valid) # pragma: no cover
meas = gather_mcm_qjit(
m, mcm_samples, is_valid, postselect_mode=postselect_mode
) # pragma: no cover
elif m.mv:
meas = gather_mcm(m, mcm_samples, is_valid)
meas = gather_mcm(m, mcm_samples, is_valid, postselect_mode=postselect_mode)
elif interface != "jax" and not has_valid:
meas = measurement_with_no_shots(m)
m_count += 1
Expand All @@ -296,12 +307,15 @@ def measurement_with_no_shots(measurement):
# We return the sum of counts (`result[1]`) weighting by `is_valid`, which is `0` for invalid samples
if isinstance(m, CountsMP):
normalized_meas.append(
(result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0))
(
result[0][0],
qml.math.sum(result[1] * qml.math.reshape(is_valid, (-1, 1)), axis=0),
)
)
m_count += 1
continue
result = qml.math.squeeze(result)
meas = gather_non_mcm(m, result, is_valid)
meas = gather_non_mcm(m, result, is_valid, postselect_mode=postselect_mode)
m_count += 1
if isinstance(m, SampleMP):
meas = qml.math.squeeze(meas)
Expand All @@ -310,7 +324,7 @@ def measurement_with_no_shots(measurement):
return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0]


def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover
def gather_mcm_qjit(measurement, samples, is_valid, postselect_mode=None): # pragma: no cover
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
"""Process MCM measurements when the Catalyst compiler is active.

Args:
Expand All @@ -331,7 +345,7 @@ def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover
raise LookupError("MCM not found")
meas = qml.math.squeeze(meas)
if isinstance(measurement, (CountsMP, ProbabilityMP)):
interface = qml.math.get_deep_interface(is_valid)
interface = qml.math.get_interface(is_valid)
sum_valid = qml.math.sum(is_valid)
count_1 = qml.math.sum(meas * is_valid)
if isinstance(measurement, CountsMP):
Expand All @@ -341,10 +355,10 @@ def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover
if isinstance(measurement, ProbabilityMP):
counts = qml.math.array([sum_valid - count_1, count_1], like=interface)
return counts / sum_valid
return gather_non_mcm(measurement, meas, is_valid)
return gather_non_mcm(measurement, meas, is_valid, postselect_mode=postselect_mode)


def gather_non_mcm(measurement, samples, is_valid):
def gather_non_mcm(measurement, samples, is_valid, postselect_mode=None):
"""Combines, gathers and normalizes several measurements with trivial measurement values.

Args:
Expand All @@ -365,25 +379,39 @@ def gather_non_mcm(measurement, samples, is_valid):
if not measurement.all_outcomes:
tmp = Counter({k: v for k, v in tmp.items() if v > 0})
return dict(sorted(tmp.items()))
if isinstance(measurement, ExpectationMP):
return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
if isinstance(measurement, ProbabilityMP):
return qml.math.sum(samples * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum(is_valid)

if isinstance(measurement, SampleMP):
is_interface_jax = qml.math.get_deep_interface(is_valid) == "jax"
if is_interface_jax and samples.ndim == 2:
is_valid = is_valid.reshape((-1, 1))
if postselect_mode == "pad-invalid-samples" and samples.ndim == 2:
is_valid = qml.math.reshape(is_valid, (-1, 1))
return (
qml.math.where(is_valid, samples, fill_in_value)
if is_interface_jax
if postselect_mode == "pad-invalid-samples"
else samples[is_valid]
)

if (interface := qml.math.get_interface(is_valid)) == "tensorflow":
# Tensorflow requires arrays that are used for arithmetic with each other to have the
# same dtype. We don't cast if measuring samples as float tf.Tensors cannot be used to
# index other tf.Tensors (is_valid is used to index valid samples).
is_valid = qml.math.cast_like(is_valid, samples)

if isinstance(measurement, ExpectationMP):
return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
if isinstance(measurement, ProbabilityMP):
return qml.math.sum(samples * qml.math.reshape(is_valid, (-1, 1)), axis=0) / qml.math.sum(
is_valid
)

# VarianceMP
expval = qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
if interface == "tensorflow":
# Casting needed for tensorflow
samples = qml.math.cast_like(samples, expval)
is_valid = qml.math.cast_like(is_valid, expval)
return qml.math.sum((samples - expval) ** 2 * is_valid) / qml.math.sum(is_valid)


def gather_mcm(measurement, samples, is_valid):
def gather_mcm(measurement, samples, is_valid, postselect_mode=None):
"""Combines, gathers and normalizes several measurements with non-trivial measurement values.

Args:
Expand All @@ -404,20 +432,30 @@ def gather_mcm(measurement, samples, is_valid):
if isinstance(measurement, ProbabilityMP):
values = [list(m.branches.values()) for m in mv]
values = list(itertools.product(*values))
values = [qml.math.array([v], like=interface) for v in values]
values = [qml.math.array([v], like=interface, dtype=mcm_samples.dtype) for v in values]
# Need to use boolean functions explicitly as Tensorflow does not allow integer math
# on boolean arrays
counts = [
qml.math.sum(qml.math.all(mcm_samples == v, axis=1) * is_valid) for v in values
qml.math.count_nonzero(
qml.math.logical_and(qml.math.all(mcm_samples == v, axis=1), is_valid)
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
)
for v in values
]
counts = qml.math.array(counts, like=interface)
return counts / qml.math.sum(counts)
if isinstance(measurement, CountsMP):
mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples]
return gather_non_mcm(measurement, mcm_samples, is_valid)
return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode)
mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface))
if isinstance(measurement, ProbabilityMP):
counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())]
# Need to use boolean functions explicitly as Tensorflow does not allow integer math
# on boolean arrays
counts = [
qml.math.count_nonzero(qml.math.logical_and((mcm_samples == v), is_valid))
for v in list(mv.branches.values())
]
counts = qml.math.array(counts, like=interface)
return counts / qml.math.sum(counts)
if isinstance(measurement, CountsMP):
mcm_samples = [{float(s): 1} for s in mcm_samples]
return gather_non_mcm(measurement, mcm_samples, is_valid)
return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode)
43 changes: 33 additions & 10 deletions pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,28 @@ def _deprecated_arguments_warnings(
return tapes, override_shots, expand_fn, max_expansion, device_batch_transform


def _update_mcm_config(mcm_config: "qml.devices.MCMConfig", interface: str, finite_shots: bool):
"""Helper function to update the mid-circuit measurements configuration based on
execution parameters"""
if interface == "jax-jit" and mcm_config.mcm_method == "deferred":
# This is a current limitation of defer_measurements. "hw-like" behaviour is
# not yet accessible.
if mcm_config.postselect_mode == "hw-like":
raise ValueError(
"Using postselect_mode='hw-like' is not supported with jax-jit when using "
"mcm_method='deferred'."
)
mcm_config.postselect_mode = "fill-shots"

if (
finite_shots
and "jax" in interface
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
and mcm_config.mcm_method in (None, "one-shot")
and mcm_config.postselect_mode in (None, "hw-like")
):
mcm_config.postselect_mode = "pad-invalid-samples"
vincentmr marked this conversation as resolved.
Show resolved Hide resolved


def execute(
tapes: QuantumTapeBatch,
device: device_type,
Expand Down Expand Up @@ -697,16 +719,17 @@ def cost_fn(params, x):
)

# Mid-circuit measurement configuration validation
mcm_interface = _get_interface_name(tapes, "auto") if interface is None else interface
if mcm_interface == "jax-jit" and config.mcm_config.mcm_method == "deferred":
# This is a current limitation of defer_measurements. "hw-like" behaviour is
# not yet accessible.
if config.mcm_config.postselect_mode == "hw-like":
raise ValueError(
"Using postselect_mode='hw-like' is not supported with jax-jit when using "
"mcm_method='deferred'."
)
config.mcm_config.postselect_mode = "fill-shots"
mcm_interface = interface or _get_interface_name(tapes, "auto")
finite_shots = (
(
qml.measurements.Shots(device.shots)
if isinstance(device, qml.devices.LegacyDevice)
else device.shots
)
if override_shots is False
else override_shots
)
_update_mcm_config(config.mcm_config, mcm_interface, finite_shots)

is_gradient_transform = isinstance(gradient_fn, qml.transforms.core.TransformDispatcher)
transform_program, inner_transform = _make_transform_programs(
Expand Down
Loading
Loading