Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catalyst supports dynamic_one_shot #5766

Merged
merged 40 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
2ecf774
Initial commit for Catalyst MCM support.
vincentmr May 28, 2024
ae86184
Move data concat up the stack so that parse_native_mid_circuit_measur…
vincentmr May 29, 2024
9ac7d71
Couple ad hoc fix for active_jit
vincentmr May 29, 2024
eb3cef2
Move logic to gather_mcm_jit
vincentmr May 29, 2024
5b7d166
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr May 29, 2024
b4541d2
_override_postselect = True if MidCircuitMeasure; deal with all_outco…
vincentmr May 30, 2024
b1977ec
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr May 30, 2024
6a43e37
_override_postselect => bypass_postselect
vincentmr May 30, 2024
ad98e1e
Fix test_parse_native_mid_circuit_measurements_unsupported_meas
vincentmr May 30, 2024
3734269
Update pennylane/transforms/dynamic_one_shot.py
vincentmr May 31, 2024
91bb9d5
Update docstrings.
vincentmr May 31, 2024
e98d5f9
Move validate_measurements to conftest.
vincentmr May 31, 2024
c7ab8a9
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr May 31, 2024
33ada01
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr Jun 3, 2024
8a1ad82
Merge remote-tracking branch 'origin/master' into mcm_for_loop_catalyst3
vincentmr Jun 3, 2024
885af7a
Split MidCircuitMeasure logic
vincentmr Jun 3, 2024
9fb8d10
Put validate in tests/helpers/utils.py
vincentmr Jun 3, 2024
0bd74cb
Update pennylane/transforms/dynamic_one_shot.py
vincentmr Jun 3, 2024
7dd313d
Update pennylane/transforms/dynamic_one_shot.py
vincentmr Jun 3, 2024
7fd6447
Rename
vincentmr Jun 3, 2024
b5c1415
Fix docstring.
vincentmr Jun 3, 2024
ef4fc27
Rename gather_mcm_jit
vincentmr Jun 3, 2024
d2c9459
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr Jun 3, 2024
646a626
utils => mcm_utils.
vincentmr Jun 3, 2024
9e577f4
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr Jun 3, 2024
e0486d0
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr Jun 4, 2024
152ce07
Indent block.
vincentmr Jun 4, 2024
70e33db
Merge remote-tracking branch 'origin/master' into mcm_for_loop_catalyst3
vincentmr Jun 4, 2024
6434d31
Fix MCMConfig default.
vincentmr Jun 4, 2024
265b0e3
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr Jun 5, 2024
cba316a
# pragma: no cover
vincentmr Jun 5, 2024
542f007
Update pennylane/transforms/dynamic_one_shot.py
vincentmr Jun 5, 2024
c2540fe
Update pennylane/transforms/dynamic_one_shot.py
vincentmr Jun 5, 2024
2ec2b2b
Merge remote-tracking branch 'origin/master' into mcm_for_loop_catalyst3
vincentmr Jun 5, 2024
6ecb359
Export is_mcm in transforms for use in tape module.
vincentmr Jun 6, 2024
d27a244
Fix _validate_computational_basis_sampling?
vincentmr Jun 6, 2024
ba08fad
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr Jun 6, 2024
8095c6f
Remove new_operations logic
vincentmr Jun 7, 2024
bd2ac1d
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr Jun 7, 2024
8e1b4fc
Merge branch 'master' into mcm_for_loop_catalyst3
vincentmr Jun 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ coverage:
.PHONY:format
format:
ifdef check
isort --py 311 --profile black -l 100 -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check
isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check
black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests --check
else
isort --py 311 --profile black -l 100 -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests
isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests
black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests
endif

Expand Down
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@

<h4>Mid-circuit measurements and dynamic circuits</h4>

* The `dynamic_one_shot` transform is made compatible with the Catalyst compiler.
[(#5766)](https://github.com/PennyLaneAI/pennylane/pull/5766)

* The `dynamic_one_shot` transform uses a single auxiliary tape with a shot vector and `default.qubit` implements the loop over shots with `jax.vmap`.
[(#5617)](https://github.com/PennyLaneAI/pennylane/pull/5617)

Expand Down
2 changes: 1 addition & 1 deletion pennylane/measurements/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def circuit(x):
# remove nans
mask = qml.math.isnan(samples)
num_wires = shape[-1]
if np.any(mask):
if qml.math.any(mask):
mask = np.logical_not(np.any(mask, axis=tuple(range(1, samples.ndim))))
samples = samples[mask, ...]

Expand Down
102 changes: 78 additions & 24 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# pylint: disable=import-outside-toplevel
from collections import Counter
from typing import Callable, Sequence
from typing import Callable, Sequence, Tuple
vincentmr marked this conversation as resolved.
Show resolved Hide resolved

import numpy as np

Expand Down Expand Up @@ -50,7 +50,7 @@ def null_postprocessing(results):
@transform
def dynamic_one_shot(
tape: qml.tape.QuantumTape, **kwargs
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
) -> Tuple[Sequence[qml.tape.QuantumTape], Callable]:
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
"""Transform a QNode to into several one-shot tapes to support dynamic circuit execution.

Args:
Expand Down Expand Up @@ -118,6 +118,9 @@ def func(x, y):

aux_tapes = [init_auxiliary_tape(t) for t in tapes]

def reshape_data(array):
return qml.math.squeeze(qml.math.vstack(array))

def processing_fn(results, has_partitioned_shots=None, batched_results=None):
if batched_results is None and batch_size is not None:
# If broadcasting, recursively process the results for each batch. For each batch
Expand All @@ -141,6 +144,14 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None):
return tuple(final_results)
if not tape.shots.has_partitioned_shots:
results = results[0]

is_scalar = not isinstance(results[0], Sequence)
if is_scalar:
results = [reshape_data(tuple(results))]
else:
results = [
reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0])
]
return parse_native_mid_circuit_measurements(tape, aux_tapes, results)

return aux_tapes, processing_fn
Expand Down Expand Up @@ -195,26 +206,35 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
new_measurements.append(SampleMP(obs=m.obs))
else:
new_measurements.append(m)
for op in circuit:
if is_mcm(op):
new_operations = []
for op in circuit.operations:
if "MidCircuitMeasure" in str(type(op)): # pragma: no cover
new_op = op
new_op.bypass_postselect = True
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
new_operations.append(new_op)
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
else:
new_operations.append(op)
if "MidCircuitMeasure" in str(type(op)): # pragma: no cover
new_measurements.append(qml.sample(op.out_classical_tracers[0]))
elif isinstance(op, MidMeasureMP):
new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res)))

return qml.tape.QuantumScript(
circuit.operations,
new_operations,
new_measurements,
shots=[1] * circuit.shots.total_shots,
trainable_params=circuit.trainable_params,
)


# pylint: disable=too-many-branches,too-many-statements
def parse_native_mid_circuit_measurements(
circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results: TensorLike
):
"""Combines, gathers and normalizes the results of native mid-circuit measurement runs.

Args:
circuit (QuantumTape): Initial ``QuantumScript``
aux_tapes (List[QuantumTape]): List of auxilary ``QuantumScript`` objects
circuit (QuantumTape): The original ``QuantumScript``
aux_tapes (List[QuantumTape]): List of auxiliary ``QuantumScript`` objects
results (TensorLike): Array of measurement results

Returns:
Expand All @@ -230,21 +250,12 @@ def measurement_with_no_shots(measurement):

interface = qml.math.get_deep_interface(circuit.data)
interface = "numpy" if interface == "builtins" else interface
active_qjit = qml.compiler.active()

all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
n_mcms = len(all_mcms)
post_process_tape = qml.tape.QuantumScript(
aux_tapes[0].operations,
aux_tapes[0].measurements[0:-n_mcms],
shots=aux_tapes[0].shots,
trainable_params=aux_tapes[0].trainable_params,
)
single_measurement = (
len(post_process_tape.measurements) == 0 and len(aux_tapes[0].measurements) == 1
)
mcm_samples = qml.math.array(
[[res] if single_measurement else res[-n_mcms::] for res in results], like=interface
)
mcm_samples = qml.math.hstack(tuple(res.reshape((-1, 1)) for res in results[-n_mcms:]))
mcm_samples = qml.math.array(mcm_samples, like=interface)
# Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
has_postselect = qml.math.array(
[[int(op.postselect is not None) for op in all_mcms]], like=interface
Expand All @@ -257,7 +268,6 @@ def measurement_with_no_shots(measurement):
mid_meas = [op for op in circuit.operations if is_mcm(op)]
mcm_samples = [mcm_samples[:, i : i + 1] for i in range(n_mcms)]
mcm_samples = dict((k, v) for k, v in zip(mid_meas, mcm_samples))

normalized_meas = []
m_count = 0
for m in circuit.measurements:
Expand All @@ -267,18 +277,28 @@ def measurement_with_no_shots(measurement):
)
if interface != "jax" and m.mv and not has_valid:
meas = measurement_with_no_shots(m)
elif m.mv and active_qjit:
meas = gather_mcm_qjit(m, mcm_samples, is_valid) # pragma: no cover
elif m.mv:
meas = gather_mcm(m, mcm_samples, is_valid)
elif interface != "jax" and not has_valid:
meas = measurement_with_no_shots(m)
m_count += 1
else:
result = [res[m_count] for res in results]
result = results[m_count]
if not isinstance(m, CountsMP):
# We don't need to cast to arrays when using qml.counts. qml.math.array is not viable
# as it assumes all elements of the input are of builtin python types and not belonging
# to any particular interface
result = qml.math.stack(result, like=interface)
result = qml.math.array(result, like=interface)
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
if active_qjit: # pragma: no cover
if isinstance(m, CountsMP):
normalized_meas.append(
(result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0))
)
m_count += 1
continue
result = qml.math.squeeze(result)
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
meas = gather_non_mcm(m, result, is_valid)
m_count += 1
if isinstance(m, SampleMP):
Expand All @@ -288,6 +308,39 @@ def measurement_with_no_shots(measurement):
return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0]


def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover
"""Process MCM measurements when the Catalyst compiler is active.

Args:
measurement (MeasurementProcess): measurement
samples (dict): Mid-circuit measurement samples
is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at
each index specifies whether or not the respective sample is valid.

Returns:
TensorLike: The combined measurement outcome
"""
found, meas = False, None
for k, meas in samples.items():
if measurement.mv is k.out_classical_tracers[0]:
found = True
break
if not found:
raise LookupError("MCM not found")
meas = qml.math.squeeze(meas)
if isinstance(measurement, (CountsMP, ProbabilityMP)):
sum_valid = qml.math.sum(is_valid)
count_1 = qml.math.sum(meas * is_valid)
if isinstance(measurement, CountsMP):
return {0: sum_valid - count_1, 1: count_1}
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(measurement, ProbabilityMP):
counts = qml.math.array(
[sum_valid - count_1, count_1], like=qml.math.get_deep_interface(is_valid)
)
return counts / sum_valid
return gather_non_mcm(measurement, meas, is_valid)


def gather_non_mcm(measurement, samples, is_valid):
"""Combines, gathers and normalizes several measurements with trivial measurement values.

Expand All @@ -306,7 +359,8 @@ def gather_non_mcm(measurement, samples, is_valid):
tmp.update(
dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items())
)
tmp = Counter({k: v for k, v in tmp.items() if v > 0})
if not measurement.all_outcomes:
tmp = Counter({k: v for k, v in tmp.items() if v > 0})
return dict(sorted(tmp.items()))
if isinstance(measurement, ExpectationMP):
return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import contextlib
import os
import pathlib
import sys

import numpy as np
import pytest
Expand All @@ -26,6 +27,8 @@
from pennylane.devices import DefaultGaussian
from pennylane.operation import disable_new_opmath_cm, enable_new_opmath_cm

sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
vincentmr marked this conversation as resolved.
Show resolved Hide resolved

# defaults
TOL = 1e-3
TF_TOL = 2e-2
Expand Down Expand Up @@ -206,6 +209,7 @@ def use_legacy_opmath():
yield cm


# pylint: disable=contextmanager-generator-missing-cleanup
@pytest.fixture(scope="function")
def use_new_opmath():
with enable_new_opmath_cm() as cm:
Expand Down
Loading
Loading