From 2ecf774ec5a85da71bce945c503c7d26c2135724 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Tue, 28 May 2024 19:46:17 +0000 Subject: [PATCH 01/26] Initial commit for Catalyst MCM support. --- pennylane/transforms/dynamic_one_shot.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 998cefa93c5..537ae87e051 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -195,8 +195,10 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): else: new_measurements.append(m) for op in circuit: - if is_mcm(op): + if isinstance(op, MidMeasureMP): new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) + if "MidCircuitMeasure" in str(type(op)): + new_measurements.append(qml.sample(op.out_classical_tracers[0])) return qml.tape.QuantumScript( circuit.operations, @@ -229,6 +231,7 @@ def measurement_with_no_shots(measurement): interface = qml.math.get_deep_interface(circuit.data) interface = "numpy" if interface == "builtins" else interface + active_jit = qml.compiler.active_compiler() all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)] n_mcms = len(all_mcms) @@ -243,7 +246,7 @@ def measurement_with_no_shots(measurement): ) mcm_samples = qml.math.array( [[res] if single_measurement else res[-n_mcms::] for res in results], like=interface - ) + ).reshape((-1, n_mcms)) # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1 has_postselect = qml.math.array( [[int(op.postselect is not None) for op in all_mcms]], like=interface @@ -266,13 +269,24 @@ def measurement_with_no_shots(measurement): ) if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) + elif m.mv and active_jit: + found = False + for k, meas in mcm_samples.items(): + if m.mv is k.out_classical_tracers[0]: + found = True + break + if not found: + raise LookupError("MCM not found") elif m.mv: meas = gather_mcm(m, mcm_samples, is_valid) elif interface != "jax" and not has_valid: meas = measurement_with_no_shots(m) m_count += 1 else: - result = [res[m_count] for res in results] + # result = [res[m_count] for res in results] + result = qml.math.squeeze( + qml.math.array([res[m_count] for res in results], like=interface) + ) if not isinstance(m, CountsMP): # We don't need to cast to arrays when using qml.counts. qml.math.array is not viable # as it assumes all elements of the input are of builtin python types and not belonging From ae86184eee336f574970d66d8c8fdce178f5ed8c Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Wed, 29 May 2024 17:32:58 +0000 Subject: [PATCH 02/26] Move data concat up the stack so that parse_native_mid_circuit_measurements can accept results with a broadcast dimension (for jitting). --- Makefile | 4 +-- pennylane/transforms/dynamic_one_shot.py | 32 +++++++++++------------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/Makefile b/Makefile index 0a7573ec0b1..5f7186e4549 100644 --- a/Makefile +++ b/Makefile @@ -70,10 +70,10 @@ coverage: .PHONY:format format: ifdef check - isort --py 311 --profile black -l 100 -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check + isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests --check black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests --check else - isort --py 311 --profile black -l 100 -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests + isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane ./tests black -t py39 -t py310 -t py311 -l 100 ./pennylane ./tests endif diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 537ae87e051..79280267093 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -117,6 +117,9 @@ def func(x, y): aux_tapes = [init_auxiliary_tape(t) for t in tapes] + def reshape_data(array): + return qml.math.squeeze(qml.math.vstack(array)) + def processing_fn(results, has_partitioned_shots=None, batched_results=None): if batched_results is None and batch_size is not None: # If broadcasting, recursively process the results for each batch. For each batch @@ -140,6 +143,14 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None): return tuple(final_results) if not tape.shots.has_partitioned_shots: results = results[0] + + is_scalar = not isinstance(results[0], Sequence) + if is_scalar: + results = [reshape_data(tuple(results))] + else: + results = [ + reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0]) + ] return parse_native_mid_circuit_measurements(tape, aux_tapes, results) return aux_tapes, processing_fn @@ -235,18 +246,8 @@ def measurement_with_no_shots(measurement): all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)] n_mcms = len(all_mcms) - post_process_tape = qml.tape.QuantumScript( - aux_tapes[0].operations, - aux_tapes[0].measurements[0:-n_mcms], - shots=aux_tapes[0].shots, - trainable_params=aux_tapes[0].trainable_params, - ) - single_measurement = ( - len(post_process_tape.measurements) == 0 and len(aux_tapes[0].measurements) == 1 - ) - mcm_samples = qml.math.array( - [[res] if single_measurement else res[-n_mcms::] for res in results], like=interface - ).reshape((-1, n_mcms)) + mcm_samples = qml.math.hstack(tuple(res.reshape((-1, 1)) for res in results[-n_mcms:])) + mcm_samples = qml.math.array(mcm_samples, like=interface) # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1 has_postselect = qml.math.array( [[int(op.postselect is not None) for op in all_mcms]], like=interface @@ -283,15 +284,12 @@ def measurement_with_no_shots(measurement): meas = measurement_with_no_shots(m) m_count += 1 else: - # result = [res[m_count] for res in results] - result = qml.math.squeeze( - qml.math.array([res[m_count] for res in results], like=interface) - ) + result = results[m_count] if not isinstance(m, CountsMP): # We don't need to cast to arrays when using qml.counts. qml.math.array is not viable # as it assumes all elements of the input are of builtin python types and not belonging # to any particular interface - result = qml.math.stack(result, like=interface) + result = qml.math.array(result, like=interface) meas = gather_non_mcm(m, result, is_valid) m_count += 1 if isinstance(m, SampleMP): From 9ac7d71205bf2067df386c42055488724592d647 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Wed, 29 May 2024 20:00:46 +0000 Subject: [PATCH 03/26] Couple ad hoc fix for active_jit --- pennylane/measurements/counts.py | 2 +- pennylane/transforms/dynamic_one_shot.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py index 903052edae1..18de33410e9 100644 --- a/pennylane/measurements/counts.py +++ b/pennylane/measurements/counts.py @@ -305,7 +305,7 @@ def circuit(x): # remove nans mask = qml.math.isnan(samples) num_wires = shape[-1] - if np.any(mask): + if qml.math.any(mask): mask = np.logical_not(np.any(mask, axis=tuple(range(1, samples.ndim)))) samples = samples[mask, ...] diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 79280267093..efa5449cf7f 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -219,6 +219,7 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): ) +# pylint: disable=too-many-branches,too-many-statements def parse_native_mid_circuit_measurements( circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results ): @@ -271,13 +272,18 @@ def measurement_with_no_shots(measurement): if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) elif m.mv and active_jit: - found = False + found, meas = False, None for k, meas in mcm_samples.items(): if m.mv is k.out_classical_tracers[0]: found = True break if not found: raise LookupError("MCM not found") + meas = qml.math.squeeze(meas) + if isinstance(m, CountsMP): + count1 = qml.math.sum(meas * is_valid) + return {0: qml.math.sum(is_valid) - count1, 1: count1} + meas = gather_non_mcm(m, meas, is_valid) elif m.mv: meas = gather_mcm(m, mcm_samples, is_valid) elif interface != "jax" and not has_valid: @@ -290,6 +296,12 @@ def measurement_with_no_shots(measurement): # as it assumes all elements of the input are of builtin python types and not belonging # to any particular interface result = qml.math.array(result, like=interface) + if active_jit: + if isinstance(m, CountsMP): + normalized_meas.append((result[0][0], qml.math.sum(result[1], axis=0))) + m_count += 1 + continue + result = qml.math.squeeze(result) meas = gather_non_mcm(m, result, is_valid) m_count += 1 if isinstance(m, SampleMP): From eb3cef2b27b37722dfb300b6aab36f32d63e0c27 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Wed, 29 May 2024 20:46:10 +0000 Subject: [PATCH 04/26] Move logic to gather_mcm_jit --- pennylane/measurements/probs.py | 4 +-- pennylane/transforms/dynamic_one_shot.py | 45 +++++++++++++++++------- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/pennylane/measurements/probs.py b/pennylane/measurements/probs.py index a0ceb468f80..0d2c45d71f0 100644 --- a/pennylane/measurements/probs.py +++ b/pennylane/measurements/probs.py @@ -93,8 +93,8 @@ def circuit(): Note that the output shape of this measurement process depends on whether the device simulates qubit or continuous variable quantum systems. """ - if isinstance(op, MeasurementValue): - if len(op.measurements) > 1: + if isinstance(op, MeasurementValue) or qml.math.is_abstract(op): + if isinstance(op, MeasurementValue) and len(op.measurements) > 1: raise ValueError( "Cannot use qml.probs() when measuring multiple mid-circuit measurements collected " "using arithmetic operators. To collect probabilities for multiple mid-circuit " diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index efa5449cf7f..f083a0ed006 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -272,18 +272,7 @@ def measurement_with_no_shots(measurement): if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) elif m.mv and active_jit: - found, meas = False, None - for k, meas in mcm_samples.items(): - if m.mv is k.out_classical_tracers[0]: - found = True - break - if not found: - raise LookupError("MCM not found") - meas = qml.math.squeeze(meas) - if isinstance(m, CountsMP): - count1 = qml.math.sum(meas * is_valid) - return {0: qml.math.sum(is_valid) - count1, 1: count1} - meas = gather_non_mcm(m, meas, is_valid) + meas = gather_mcm_jit(m, mcm_samples, is_valid) elif m.mv: meas = gather_mcm(m, mcm_samples, is_valid) elif interface != "jax" and not has_valid: @@ -311,6 +300,38 @@ def measurement_with_no_shots(measurement): return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] +def gather_mcm_jit(circuit_measurement, measurement, is_valid): + """Combines, gathers and normalizes several measurements with trivial measurement values + when the Catalyst compiler is active. + + Args: + circuit_measurement (MeasurementProcess): measurement + measurement (TensorLike): measurement results + samples (List[dict]): Mid-circuit measurement samples + + Returns: + TensorLike: The combined measurement outcome + """ + found, meas = False, None + for k, meas in measurement.items(): + if circuit_measurement.mv is k.out_classical_tracers[0]: + found = True + break + if not found: + raise LookupError("MCM not found") + meas = qml.math.squeeze(meas) + if isinstance(circuit_measurement, CountsMP): + count1 = qml.math.sum(meas * is_valid) + return {0: qml.math.sum(is_valid) - count1, 1: count1} + if isinstance(circuit_measurement, ProbabilityMP): + count1 = qml.math.sum(meas * is_valid) + counts = qml.math.array( + [qml.math.sum(is_valid) - count1, count1], like=qml.math.get_deep_interface(is_valid) + ) + return counts / qml.math.sum(is_valid) + return gather_non_mcm(circuit_measurement, meas, is_valid) + + def gather_non_mcm(circuit_measurement, measurement, is_valid): """Combines, gathers and normalizes several measurements with trivial measurement values. From b4541d25ce5bdaf6c9421aef1f25315ffd727a25 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Thu, 30 May 2024 19:41:32 +0000 Subject: [PATCH 05/26] _override_postselect = True if MidCircuitMeasure; deal with all_outcomes. --- pennylane/transforms/dynamic_one_shot.py | 20 +++++++++++++------ .../test_default_qubit_native_mcm.py | 2 +- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index f083a0ed006..f46747e4915 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -186,6 +186,7 @@ def is_mcm(operation): return mcm or "MidCircuitMeasure" in str(type(operation)) +# pylint: disable=protected-access def init_auxiliary_tape(circuit: qml.tape.QuantumScript): """Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations. @@ -205,14 +206,19 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): new_measurements.append(SampleMP(obs=m.obs)) else: new_measurements.append(m) - for op in circuit: + new_operations = [] + for op in circuit.operations: if isinstance(op, MidMeasureMP): new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) if "MidCircuitMeasure" in str(type(op)): new_measurements.append(qml.sample(op.out_classical_tracers[0])) - + new_op = op + op._override_postselect = True + new_operations.append(new_op) + else: + new_operations.append(op) return qml.tape.QuantumScript( - circuit.operations, + new_operations, new_measurements, shots=[1] * circuit.shots.total_shots, trainable_params=circuit.trainable_params, @@ -261,7 +267,6 @@ def measurement_with_no_shots(measurement): mid_meas = [op for op in circuit.operations if is_mcm(op)] mcm_samples = [mcm_samples[:, i : i + 1] for i in range(n_mcms)] mcm_samples = dict((k, v) for k, v in zip(mid_meas, mcm_samples)) - normalized_meas = [] m_count = 0 for m in circuit.measurements: @@ -287,7 +292,9 @@ def measurement_with_no_shots(measurement): result = qml.math.array(result, like=interface) if active_jit: if isinstance(m, CountsMP): - normalized_meas.append((result[0][0], qml.math.sum(result[1], axis=0))) + normalized_meas.append( + (result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0)) + ) m_count += 1 continue result = qml.math.squeeze(result) @@ -349,7 +356,8 @@ def gather_non_mcm(circuit_measurement, measurement, is_valid): tmp.update( dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items()) ) - tmp = Counter({k: v for k, v in tmp.items() if v > 0}) + if not circuit_measurement.all_outcomes: + tmp = Counter({k: v for k, v in tmp.items() if v > 0}) return dict(sorted(tmp.items())) if isinstance(circuit_measurement, ExpectationMP): return qml.math.sum(measurement * is_valid) / qml.math.sum(is_valid) diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index e2184874cec..92664680cf9 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -292,7 +292,7 @@ def test_single_mcm_multiple_measure_obs(postselect, reset): @qml.qnode(dev) def func(x, y, z): obs_tape(x, y, z, reset=reset, postselect=postselect) - return qml.counts(qml.PauliZ(0)), qml.expval(qml.PauliY(1)) + return qml.counts(qml.PauliZ(0), all_outcomes=True), qml.expval(qml.PauliY(1)) func1 = func func2 = qml.defer_measurements(func) From 6a43e37f6f3c6d580b119c7c3a49ffe3207aa345 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Thu, 30 May 2024 19:43:39 +0000 Subject: [PATCH 06/26] _override_postselect => bypass_postselect --- pennylane/transforms/dynamic_one_shot.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index f46747e4915..636bf41f672 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -186,7 +186,6 @@ def is_mcm(operation): return mcm or "MidCircuitMeasure" in str(type(operation)) -# pylint: disable=protected-access def init_auxiliary_tape(circuit: qml.tape.QuantumScript): """Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations. @@ -213,7 +212,7 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): if "MidCircuitMeasure" in str(type(op)): new_measurements.append(qml.sample(op.out_classical_tracers[0])) new_op = op - op._override_postselect = True + op.bypass_postselect = True new_operations.append(new_op) else: new_operations.append(op) From ad98e1e457b355b74fa4bd667f04615ed9b76299 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Thu, 30 May 2024 20:19:40 +0000 Subject: [PATCH 07/26] Fix test_parse_native_mid_circuit_measurements_unsupported_meas --- tests/transforms/test_dynamic_one_shot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index ce9d67a429b..efe51c3191f 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -45,7 +45,7 @@ def test_parse_native_mid_circuit_measurements_unsupported_meas(measurement): circuit = qml.tape.QuantumScript([qml.RX(1.0, 0)], [measurement]) with pytest.raises(TypeError, match="Native mid-circuit measurement mode does not support"): - parse_native_mid_circuit_measurements(circuit, [circuit], [[]]) + parse_native_mid_circuit_measurements(circuit, [circuit], [np.empty((0,))]) def test_postselection_error_with_wrong_device(): From 3734269e77418c11c9ffe1f61c773447f6ed7f59 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Fri, 31 May 2024 14:23:59 -0400 Subject: [PATCH 08/26] Update pennylane/transforms/dynamic_one_shot.py Co-authored-by: David Ittah --- pennylane/transforms/dynamic_one_shot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 636bf41f672..fa9965a3074 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -209,7 +209,7 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): for op in circuit.operations: if isinstance(op, MidMeasureMP): new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) - if "MidCircuitMeasure" in str(type(op)): + elif "MidCircuitMeasure" in str(type(op)): new_measurements.append(qml.sample(op.out_classical_tracers[0])) new_op = op op.bypass_postselect = True From 91bb9d540103b808e3a4ea252c6b5174bfb2626f Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Fri, 31 May 2024 18:35:49 +0000 Subject: [PATCH 09/26] Update docstrings. --- doc/releases/changelog-dev.md | 3 +++ pennylane/measurements/probs.py | 4 ++-- pennylane/transforms/dynamic_one_shot.py | 10 +++++----- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 323549d8155..867418622aa 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -26,6 +26,9 @@

Mid-circuit measurements and dynamic circuits

+* The `dynamic_one_shot` transform is made compatible with the Catalyst compiler. + [(#5766)](https://github.com/PennyLaneAI/pennylane/pull/5766) + * The `dynamic_one_shot` transform uses a single auxiliary tape with a shot vector and `default.qubit` implements the loop over shots with `jax.vmap`. [(#5617)](https://github.com/PennyLaneAI/pennylane/pull/5617) diff --git a/pennylane/measurements/probs.py b/pennylane/measurements/probs.py index 4773beba70f..c1c0dc45c6e 100644 --- a/pennylane/measurements/probs.py +++ b/pennylane/measurements/probs.py @@ -93,8 +93,8 @@ def circuit(): Note that the output shape of this measurement process depends on whether the device simulates qubit or continuous variable quantum systems. """ - if isinstance(op, MeasurementValue) or qml.math.is_abstract(op): - if isinstance(op, MeasurementValue) and len(op.measurements) > 1: + if isinstance(op, MeasurementValue): + if len(op.measurements) > 1: raise ValueError( "Cannot use qml.probs() when measuring multiple mid-circuit measurements collected " "using arithmetic operators. To collect probabilities for multiple mid-circuit " diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index fa9965a3074..29a6eca8fa4 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -209,10 +209,10 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): for op in circuit.operations: if isinstance(op, MidMeasureMP): new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) - elif "MidCircuitMeasure" in str(type(op)): + if "MidCircuitMeasure" in str(type(op)): new_measurements.append(qml.sample(op.out_classical_tracers[0])) new_op = op - op.bypass_postselect = True + new_op.bypass_postselect = True new_operations.append(new_op) else: new_operations.append(op) @@ -231,9 +231,9 @@ def parse_native_mid_circuit_measurements( """Combines, gathers and normalizes the results of native mid-circuit measurement runs. Args: - circuit (QuantumTape): A one-shot (auxiliary) QuantumScript - all_shot_meas (Sequence[Any]): List of accumulated measurement results - mcm_shot_meas (Sequence[dict]): List of dictionaries containing the mid-circuit measurement results of each shot + circuit (QuantumTape): The original tape + aux_tapes tuple[QuantumTape]: A tuple of transformed tapes + results tuple[Sequence[Any]]: A tuple of results with length n-shots Returns: tuple(TensorLike): The results of the simulation From e98d5f91867e36d18ea5e2818fd86c165cb9fa5d Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Fri, 31 May 2024 18:55:52 +0000 Subject: [PATCH 10/26] Move validate_measurements to conftest. --- tests/conftest.py | 109 +++++++++++++++++ .../test_default_qubit_native_mcm.py | 110 +----------------- tests/test_compiler.py | 73 ++++++++++++ 3 files changed, 184 insertions(+), 108 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1b90d1f20f5..8b3e1efd3ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,8 @@ import contextlib import os import pathlib +from functools import reduce +from typing import Iterable, Sequence import numpy as np import pytest @@ -206,6 +208,7 @@ def use_legacy_opmath(): yield cm +# pylint: disable=contextmanager-generator-missing-cleanup @pytest.fixture(scope="function") def use_new_opmath(): with enable_new_opmath_cm() as cm: @@ -332,3 +335,109 @@ def pytest_runtest_setup(item): pytest.skip( f"\nTest {item.nodeid} only runs with {allowed_interfaces} interfaces(s) but {b} interface provided", ) + + +def validate_counts(shots, results1, results2, batch_size=None): + """Compares two counts. + + If the results are ``Sequence``s, loop over entries. + + Fails if a key of ``results1`` is not found in ``results2``. + Passes if counts are too low, chosen as ``100``. + Otherwise, fails if counts differ by more than ``20`` plus 20 percent. + """ + if isinstance(shots, Sequence): + assert isinstance(results1, tuple) + assert isinstance(results2, tuple) + assert len(results1) == len(results2) == len(shots) + for s, r1, r2 in zip(shots, results1, results2): + validate_counts(s, r1, r2, batch_size=batch_size) + return + + if batch_size is not None: + assert isinstance(results1, Iterable) + assert isinstance(results2, Iterable) + assert len(results1) == len(results2) == batch_size + for r1, r2 in zip(results1, results2): + validate_counts(shots, r1, r2, batch_size=None) + return + + for key1, val1 in results1.items(): + val2 = results2[key1] + if abs(val1 + val2) > 100: + assert np.allclose(val1, val2, atol=20, rtol=0.2) + + +def validate_samples(shots, results1, results2, batch_size=None): + """Compares two samples. + + If the results are ``Sequence``s, loop over entries. + + Fails if the results do not have the same shape, within ``20`` entries plus 20 percent. + This is to handle cases when post-selection yields variable shapes. + Otherwise, fails if the sums of samples differ by more than ``20`` plus 20 percent. + """ + if isinstance(shots, Sequence): + assert isinstance(results1, tuple) + assert isinstance(results2, tuple) + assert len(results1) == len(results2) == len(shots) + for s, r1, r2 in zip(shots, results1, results2): + validate_samples(s, r1, r2, batch_size=batch_size) + return + + if batch_size is not None: + assert isinstance(results1, Iterable) + assert isinstance(results2, Iterable) + assert len(results1) == len(results2) == batch_size + for r1, r2 in zip(results1, results2): + validate_samples(shots, r1, r2, batch_size=None) + return + + sh1, sh2 = results1.shape[0], results2.shape[0] + assert np.allclose(sh1, sh2, atol=20, rtol=0.2) + assert results1.ndim == results2.ndim + if results2.ndim > 1: + assert results1.shape[1] == results2.shape[1] + np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2) + + +def validate_expval(shots, results1, results2, batch_size=None): + """Compares two expval, probs or var. + + If the results are ``Sequence``s, validate the average of items. + + If ``shots is None``, validate using ``np.allclose``'s default parameters. + Otherwise, fails if the results do not match within ``0.01`` plus 20 percent. + """ + if isinstance(shots, Sequence): + assert isinstance(results1, tuple) + assert isinstance(results2, tuple) + assert len(results1) == len(results2) == len(shots) + results1 = reduce(lambda x, y: x + y, results1) / len(results1) + results2 = reduce(lambda x, y: x + y, results2) / len(results2) + validate_expval(sum(shots), results1, results2, batch_size=batch_size) + return + + if shots is None: + assert np.allclose(results1, results2) + return + + if batch_size is not None: + assert len(results1) == len(results2) == batch_size + for r1, r2 in zip(results1, results2): + validate_expval(shots, r1, r2, batch_size=None) + + assert np.allclose(results1, results2, atol=0.01, rtol=0.2) + + +def validate_measurements(func, shots, results1, results2, batch_size=None): + """Calls the correct validation function based on measurement type.""" + if func is qml.counts: + validate_counts(shots, results1, results2, batch_size=batch_size) + return + + if func is qml.sample: + validate_samples(shots, results1, results2, batch_size=batch_size) + return + + validate_expval(shots, results1, results2, batch_size=batch_size) diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index 92664680cf9..bc9cea8f759 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for default qubit preprocessing.""" -from functools import reduce -from typing import Iterable, Sequence +from typing import Sequence import numpy as np import pytest +from conftest import validate_measurements import pennylane as qml from pennylane.devices.qubit.apply_operation import MidMeasureMP, apply_mid_measure @@ -31,112 +31,6 @@ def get_device(**kwargs): return qml.device("default.qubit", **kwargs) -def validate_counts(shots, results1, results2, batch_size=None): - """Compares two counts. - - If the results are ``Sequence``s, loop over entries. - - Fails if a key of ``results1`` is not found in ``results2``. - Passes if counts are too low, chosen as ``100``. - Otherwise, fails if counts differ by more than ``20`` plus 20 percent. - """ - if isinstance(shots, Sequence): - assert isinstance(results1, tuple) - assert isinstance(results2, tuple) - assert len(results1) == len(results2) == len(shots) - for s, r1, r2 in zip(shots, results1, results2): - validate_counts(s, r1, r2, batch_size=batch_size) - return - - if batch_size is not None: - assert isinstance(results1, Iterable) - assert isinstance(results2, Iterable) - assert len(results1) == len(results2) == batch_size - for r1, r2 in zip(results1, results2): - validate_counts(shots, r1, r2, batch_size=None) - return - - for key1, val1 in results1.items(): - val2 = results2[key1] - if abs(val1 + val2) > 100: - assert np.allclose(val1, val2, atol=20, rtol=0.2) - - -def validate_samples(shots, results1, results2, batch_size=None): - """Compares two samples. - - If the results are ``Sequence``s, loop over entries. - - Fails if the results do not have the same shape, within ``20`` entries plus 20 percent. - This is to handle cases when post-selection yields variable shapes. - Otherwise, fails if the sums of samples differ by more than ``20`` plus 20 percent. - """ - if isinstance(shots, Sequence): - assert isinstance(results1, tuple) - assert isinstance(results2, tuple) - assert len(results1) == len(results2) == len(shots) - for s, r1, r2 in zip(shots, results1, results2): - validate_samples(s, r1, r2, batch_size=batch_size) - return - - if batch_size is not None: - assert isinstance(results1, Iterable) - assert isinstance(results2, Iterable) - assert len(results1) == len(results2) == batch_size - for r1, r2 in zip(results1, results2): - validate_samples(shots, r1, r2, batch_size=None) - return - - sh1, sh2 = results1.shape[0], results2.shape[0] - assert np.allclose(sh1, sh2, atol=20, rtol=0.2) - assert results1.ndim == results2.ndim - if results2.ndim > 1: - assert results1.shape[1] == results2.shape[1] - np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2) - - -def validate_expval(shots, results1, results2, batch_size=None): - """Compares two expval, probs or var. - - If the results are ``Sequence``s, validate the average of items. - - If ``shots is None``, validate using ``np.allclose``'s default parameters. - Otherwise, fails if the results do not match within ``0.01`` plus 20 percent. - """ - if isinstance(shots, Sequence): - assert isinstance(results1, tuple) - assert isinstance(results2, tuple) - assert len(results1) == len(results2) == len(shots) - results1 = reduce(lambda x, y: x + y, results1) / len(results1) - results2 = reduce(lambda x, y: x + y, results2) / len(results2) - validate_expval(sum(shots), results1, results2, batch_size=batch_size) - return - - if shots is None: - assert np.allclose(results1, results2) - return - - if batch_size is not None: - assert len(results1) == len(results2) == batch_size - for r1, r2 in zip(results1, results2): - validate_expval(shots, r1, r2, batch_size=None) - - assert np.allclose(results1, results2, atol=0.01, rtol=0.2) - - -def validate_measurements(func, shots, results1, results2, batch_size=None): - """Calls the correct validation function based on measurement type.""" - if func is qml.counts: - validate_counts(shots, results1, results2, batch_size=batch_size) - return - - if func is qml.sample: - validate_samples(shots, results1, results2, batch_size=batch_size) - return - - validate_expval(shots, results1, results2, batch_size=batch_size) - - def test_apply_mid_measure(): """Test that apply_mid_measure raises if applied to a batched state.""" with pytest.raises(ValueError, match="MidMeasureMP cannot be applied to batched states."): diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 0df35a567e2..e747d178351 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -18,10 +18,12 @@ from unittest.mock import patch import pytest +from conftest import validate_measurements import pennylane as qml from pennylane import numpy as np from pennylane.compiler.compiler import CompileError +from pennylane.transforms.dynamic_one_shot import fill_in_value catalyst = pytest.importorskip("catalyst") jax = pytest.importorskip("jax") @@ -737,3 +739,74 @@ def circuit(x): assert circuit(0.0) == 0 assert circuit(jnp.pi) == 1 + + +class TestCatalystMCMs: + """Test dynamic_one_shot with Catalyst.""" + + @pytest.mark.xfail(reason="requires simultaneous catalyst pr") + @pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs]) + @pytest.mark.parametrize("meas_obj", [qml.PauliZ(0), [0], "mcm"]) + # pylint: disable=too-many-arguments + def test_dynamic_one_shot_simple(self, measure_f, meas_obj): + """Tests that Catalyst yields the same results as PennyLane's DefaultQubit for a simple + circuit with a mid-circuit measurement.""" + if measure_f in (qml.counts, qml.probs, qml.sample) and ( + not isinstance(meas_obj, list) and not meas_obj == "mcm" + ): + pytest.skip("Can't use observables with counts, probs or sample") + + if measure_f in (qml.var, qml.expval) and (isinstance(meas_obj, list)): + pytest.skip("Can't use wires/mcm lists with var or expval") + + if measure_f == qml.var and (not isinstance(meas_obj, list) and not meas_obj == "mcm"): + pytest.xfail("isa") + shots = 8000 + + dq = qml.device("default.qubit", shots=shots, seed=8237945) + + @qml.defer_measurements + @qml.qnode(dq) + def ref_func(x, y): + qml.RX(x, wires=0) + m0 = qml.measure(0) + qml.cond(m0, qml.RY)(y, wires=1) + + meas_key = "wires" if isinstance(meas_obj, list) else "op" + meas_value = m0 if isinstance(meas_obj, str) else meas_obj + kwargs = {meas_key: meas_value} + if measure_f == qml.counts: + kwargs["all_outcomes"] = True + return measure_f(**kwargs) + + dev = qml.device("lightning.qubit", wires=2, shots=shots) + + @qml.qjit + @catalyst.dynamic_one_shot + @qml.qnode(dev) + def func(x, y): + qml.RX(x, wires=0) + m0 = catalyst.measure(0) + + @catalyst.cond(m0 == 1) + def ansatz(): + qml.RY(y, wires=1) + + ansatz() + + meas_key = "wires" if isinstance(meas_obj, list) else "op" + meas_value = m0 if isinstance(meas_obj, str) else meas_obj + kwargs = {meas_key: meas_value} + return measure_f(**kwargs) + + params = jnp.pi / 4 * jnp.ones(2) + results0 = ref_func(*params) + results1 = func(*params) + if measure_f == qml.counts and isinstance(meas_obj, list): + results1 = { + format(int(state), f"0{len(meas_obj)}b"): count for state, count in zip(*results1) + } + if measure_f == qml.sample: + results0 = results0[results0 != fill_in_value] + results1 = results1[results1 != fill_in_value] + validate_measurements(measure_f, shots, results1, results0) From 885af7a2e8bf47a8e32974668729b00e6ba9336f Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 3 Jun 2024 14:09:33 +0000 Subject: [PATCH 11/26] Split MidCircuitMeasure logic --- pennylane/transforms/dynamic_one_shot.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 29a6eca8fa4..58b1d784ec4 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -207,15 +207,16 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): new_measurements.append(m) new_operations = [] for op in circuit.operations: - if isinstance(op, MidMeasureMP): - new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) if "MidCircuitMeasure" in str(type(op)): - new_measurements.append(qml.sample(op.out_classical_tracers[0])) new_op = op new_op.bypass_postselect = True new_operations.append(new_op) else: new_operations.append(op) + if isinstance(op, MidMeasureMP): + new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) + elif "MidCircuitMeasure" in str(type(op)): + new_measurements.append(qml.sample(op.out_classical_tracers[0])) return qml.tape.QuantumScript( new_operations, new_measurements, @@ -327,14 +328,16 @@ def gather_mcm_jit(circuit_measurement, measurement, is_valid): raise LookupError("MCM not found") meas = qml.math.squeeze(meas) if isinstance(circuit_measurement, CountsMP): - count1 = qml.math.sum(meas * is_valid) - return {0: qml.math.sum(is_valid) - count1, 1: count1} + sum_valid = qml.math.sum(is_valid) + count_1 = qml.math.sum(meas * is_valid) + return {0: sum_valid - count_1, 1: count_1} if isinstance(circuit_measurement, ProbabilityMP): - count1 = qml.math.sum(meas * is_valid) + sum_valid = qml.math.sum(is_valid) + count_1 = qml.math.sum(meas * is_valid) counts = qml.math.array( - [qml.math.sum(is_valid) - count1, count1], like=qml.math.get_deep_interface(is_valid) + [sum_valid - count_1, count_1], like=qml.math.get_deep_interface(is_valid) ) - return counts / qml.math.sum(is_valid) + return counts / sum_valid return gather_non_mcm(circuit_measurement, meas, is_valid) From 9fb8d10f3b0ff6dec766765f0ddb2ed1d807c4ef Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 3 Jun 2024 14:30:34 +0000 Subject: [PATCH 12/26] Put validate in tests/helpers/utils.py --- tests/conftest.py | 111 +-------------- .../test_default_qubit_native_mcm.py | 30 ++-- tests/helpers/utils.py | 128 ++++++++++++++++++ tests/test_compiler.py | 5 +- 4 files changed, 149 insertions(+), 125 deletions(-) create mode 100644 tests/helpers/utils.py diff --git a/tests/conftest.py b/tests/conftest.py index 8b3e1efd3ab..7f2eb29e46f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,8 +18,7 @@ import contextlib import os import pathlib -from functools import reduce -from typing import Iterable, Sequence +import sys import numpy as np import pytest @@ -28,6 +27,8 @@ from pennylane.devices import DefaultGaussian from pennylane.operation import disable_new_opmath_cm, enable_new_opmath_cm +sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) + # defaults TOL = 1e-3 TF_TOL = 2e-2 @@ -335,109 +336,3 @@ def pytest_runtest_setup(item): pytest.skip( f"\nTest {item.nodeid} only runs with {allowed_interfaces} interfaces(s) but {b} interface provided", ) - - -def validate_counts(shots, results1, results2, batch_size=None): - """Compares two counts. - - If the results are ``Sequence``s, loop over entries. - - Fails if a key of ``results1`` is not found in ``results2``. - Passes if counts are too low, chosen as ``100``. - Otherwise, fails if counts differ by more than ``20`` plus 20 percent. - """ - if isinstance(shots, Sequence): - assert isinstance(results1, tuple) - assert isinstance(results2, tuple) - assert len(results1) == len(results2) == len(shots) - for s, r1, r2 in zip(shots, results1, results2): - validate_counts(s, r1, r2, batch_size=batch_size) - return - - if batch_size is not None: - assert isinstance(results1, Iterable) - assert isinstance(results2, Iterable) - assert len(results1) == len(results2) == batch_size - for r1, r2 in zip(results1, results2): - validate_counts(shots, r1, r2, batch_size=None) - return - - for key1, val1 in results1.items(): - val2 = results2[key1] - if abs(val1 + val2) > 100: - assert np.allclose(val1, val2, atol=20, rtol=0.2) - - -def validate_samples(shots, results1, results2, batch_size=None): - """Compares two samples. - - If the results are ``Sequence``s, loop over entries. - - Fails if the results do not have the same shape, within ``20`` entries plus 20 percent. - This is to handle cases when post-selection yields variable shapes. - Otherwise, fails if the sums of samples differ by more than ``20`` plus 20 percent. - """ - if isinstance(shots, Sequence): - assert isinstance(results1, tuple) - assert isinstance(results2, tuple) - assert len(results1) == len(results2) == len(shots) - for s, r1, r2 in zip(shots, results1, results2): - validate_samples(s, r1, r2, batch_size=batch_size) - return - - if batch_size is not None: - assert isinstance(results1, Iterable) - assert isinstance(results2, Iterable) - assert len(results1) == len(results2) == batch_size - for r1, r2 in zip(results1, results2): - validate_samples(shots, r1, r2, batch_size=None) - return - - sh1, sh2 = results1.shape[0], results2.shape[0] - assert np.allclose(sh1, sh2, atol=20, rtol=0.2) - assert results1.ndim == results2.ndim - if results2.ndim > 1: - assert results1.shape[1] == results2.shape[1] - np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2) - - -def validate_expval(shots, results1, results2, batch_size=None): - """Compares two expval, probs or var. - - If the results are ``Sequence``s, validate the average of items. - - If ``shots is None``, validate using ``np.allclose``'s default parameters. - Otherwise, fails if the results do not match within ``0.01`` plus 20 percent. - """ - if isinstance(shots, Sequence): - assert isinstance(results1, tuple) - assert isinstance(results2, tuple) - assert len(results1) == len(results2) == len(shots) - results1 = reduce(lambda x, y: x + y, results1) / len(results1) - results2 = reduce(lambda x, y: x + y, results2) / len(results2) - validate_expval(sum(shots), results1, results2, batch_size=batch_size) - return - - if shots is None: - assert np.allclose(results1, results2) - return - - if batch_size is not None: - assert len(results1) == len(results2) == batch_size - for r1, r2 in zip(results1, results2): - validate_expval(shots, r1, r2, batch_size=None) - - assert np.allclose(results1, results2, atol=0.01, rtol=0.2) - - -def validate_measurements(func, shots, results1, results2, batch_size=None): - """Calls the correct validation function based on measurement type.""" - if func is qml.counts: - validate_counts(shots, results1, results2, batch_size=batch_size) - return - - if func is qml.sample: - validate_samples(shots, results1, results2, batch_size=batch_size) - return - - validate_expval(shots, results1, results2, batch_size=batch_size) diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index bc9cea8f759..b3674d87414 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -16,7 +16,7 @@ import numpy as np import pytest -from conftest import validate_measurements +import utils import pennylane as qml from pennylane.devices.qubit.apply_operation import MidMeasureMP, apply_mid_measure @@ -123,7 +123,7 @@ def func(x, y): results1 = func1(*params) results2 = func2(*params) - validate_measurements(measure_f, shots, results1, results2) + utils.validate_measurements(measure_f, shots, results1, results2) # pylint: disable=unused-argument @@ -171,7 +171,7 @@ def func(x, y, z): results1 = func1(*params) results2 = func2(*params) - validate_measurements(measure_f, shots, results1, results2) + utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("postselect", [None, 0, 1]) @@ -195,7 +195,7 @@ def func(x, y, z): results2 = func2(*params) for measure_f, res1, res2 in zip([qml.counts, qml.expval], results1, results2): - validate_measurements(measure_f, 5000, res1, res2) + utils.validate_measurements(measure_f, 5000, res1, res2) @pytest.mark.parametrize("shots", [None, 3000, [3000, 3001]]) @@ -226,7 +226,7 @@ def func(x, y): results1 = func1(*params) results2 = func2(*params) - validate_measurements(measure_f, shots, results1, results2) + utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("postselect", [None, 0, 1]) @@ -255,11 +255,11 @@ def func(x, y, z): if isinstance(shots, Sequence): for s, r1, r2 in zip(shots, results1, results2): for _r1, _r2 in zip(r1, r2): - validate_measurements(measure_f, s, _r1, _r2) + utils.validate_measurements(measure_f, s, _r1, _r2) return for r1, r2 in zip(results1, results2): - validate_measurements(measure_f, shots, r1, r2) + utils.validate_measurements(measure_f, shots, r1, r2) @pytest.mark.parametrize( @@ -299,7 +299,7 @@ def func(x): results1 = func1(param) results2 = func2(param) - validate_measurements(measure_f, shots, results1, results2) + utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("shots", [None, 5000, [5000, 5001]]) @@ -341,7 +341,7 @@ def func(x): results1 = func1(param) results2 = func2(param) - validate_measurements(measure_f, shots, results1, results2) + utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("shots", [None, 5000, [5000, 5001]]) @@ -372,7 +372,7 @@ def func(x, y, z): results1 = func1(*params) results2 = func2(*params) - validate_measurements(measure_f, shots, results1, results2) + utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("shots", [7500, [5000, 5001]]) @@ -403,7 +403,7 @@ def func(x): results1 = func1(param) results2 = func2(param) - validate_measurements(measure_f, shots, results1, results2) + utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("shots", [5000]) @@ -435,7 +435,7 @@ def func(x, y): results1 = func1(*param) results2 = func2(*param) - validate_measurements(measure_f, shots, results1, results2) + utils.validate_measurements(measure_f, shots, results1, results2) grad1 = qml.grad(func)(*param) grad2 = qml.grad(func2)(*param) @@ -467,7 +467,7 @@ def func(x, y): results1 = func1(*param) results2 = func2(*param) - validate_measurements(measure_fn, shots, results1, results2, batch_size=2) + utils.validate_measurements(measure_fn, shots, results1, results2, batch_size=2) if measure_fn is qml.sample and postselect is None: for i in range(2): # batch_size @@ -524,7 +524,7 @@ def func(x, y): results1 = func1(*param) results2 = func2(*param) - validate_measurements(qml.sample, shots, results1, results2, batch_size=None) + utils.validate_measurements(qml.sample, shots, results1, results2, batch_size=None) evals = obs.eigvals() for eig in evals: @@ -688,4 +688,4 @@ def func(x): results1 = func1(param) results2 = func2(param) - validate_measurements(measure_f, shots, results1, results2) + utils.validate_measurements(measure_f, shots, results1, results2) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py new file mode 100644 index 00000000000..fab28c2bfca --- /dev/null +++ b/tests/helpers/utils.py @@ -0,0 +1,128 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Pytest helper functions are defined in this module. +""" +from functools import reduce +from typing import Iterable, Sequence + +import numpy as np + +import pennylane as qml + + +def validate_counts(shots, results1, results2, batch_size=None): + """Compares two counts. + + If the results are ``Sequence``s, loop over entries. + + Fails if a key of ``results1`` is not found in ``results2``. + Passes if counts are too low, chosen as ``100``. + Otherwise, fails if counts differ by more than ``20`` plus 20 percent. + """ + if isinstance(shots, Sequence): + assert isinstance(results1, tuple) + assert isinstance(results2, tuple) + assert len(results1) == len(results2) == len(shots) + for s, r1, r2 in zip(shots, results1, results2): + validate_counts(s, r1, r2, batch_size=batch_size) + return + + if batch_size is not None: + assert isinstance(results1, Iterable) + assert isinstance(results2, Iterable) + assert len(results1) == len(results2) == batch_size + for r1, r2 in zip(results1, results2): + validate_counts(shots, r1, r2, batch_size=None) + return + + for key1, val1 in results1.items(): + val2 = results2[key1] + if abs(val1 + val2) > 100: + assert np.allclose(val1, val2, atol=20, rtol=0.2) + + +def validate_samples(shots, results1, results2, batch_size=None): + """Compares two samples. + + If the results are ``Sequence``s, loop over entries. + + Fails if the results do not have the same shape, within ``20`` entries plus 20 percent. + This is to handle cases when post-selection yields variable shapes. + Otherwise, fails if the sums of samples differ by more than ``20`` plus 20 percent. + """ + if isinstance(shots, Sequence): + assert isinstance(results1, tuple) + assert isinstance(results2, tuple) + assert len(results1) == len(results2) == len(shots) + for s, r1, r2 in zip(shots, results1, results2): + validate_samples(s, r1, r2, batch_size=batch_size) + return + + if batch_size is not None: + assert isinstance(results1, Iterable) + assert isinstance(results2, Iterable) + assert len(results1) == len(results2) == batch_size + for r1, r2 in zip(results1, results2): + validate_samples(shots, r1, r2, batch_size=None) + return + + sh1, sh2 = results1.shape[0], results2.shape[0] + assert np.allclose(sh1, sh2, atol=20, rtol=0.2) + assert results1.ndim == results2.ndim + if results2.ndim > 1: + assert results1.shape[1] == results2.shape[1] + np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2) + + +def validate_expval(shots, results1, results2, batch_size=None): + """Compares two expval, probs or var. + + If the results are ``Sequence``s, validate the average of items. + + If ``shots is None``, validate using ``np.allclose``'s default parameters. + Otherwise, fails if the results do not match within ``0.01`` plus 20 percent. + """ + if isinstance(shots, Sequence): + assert isinstance(results1, tuple) + assert isinstance(results2, tuple) + assert len(results1) == len(results2) == len(shots) + results1 = reduce(lambda x, y: x + y, results1) / len(results1) + results2 = reduce(lambda x, y: x + y, results2) / len(results2) + validate_expval(sum(shots), results1, results2, batch_size=batch_size) + return + + if shots is None: + assert np.allclose(results1, results2) + return + + if batch_size is not None: + assert len(results1) == len(results2) == batch_size + for r1, r2 in zip(results1, results2): + validate_expval(shots, r1, r2, batch_size=None) + + assert np.allclose(results1, results2, atol=0.01, rtol=0.2) + + +def validate_measurements(func, shots, results1, results2, batch_size=None): + """Calls the correct validation function based on measurement type.""" + if func is qml.counts: + validate_counts(shots, results1, results2, batch_size=batch_size) + return + + if func is qml.sample: + validate_samples(shots, results1, results2, batch_size=batch_size) + return + + validate_expval(shots, results1, results2, batch_size=batch_size) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index e747d178351..affd1962114 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -17,8 +17,9 @@ # pylint: disable=import-outside-toplevel from unittest.mock import patch +import numpy as np import pytest -from conftest import validate_measurements +import utils import pennylane as qml from pennylane import numpy as np @@ -809,4 +810,4 @@ def ansatz(): if measure_f == qml.sample: results0 = results0[results0 != fill_in_value] results1 = results1[results1 != fill_in_value] - validate_measurements(measure_f, shots, results1, results0) + utils.validate_measurements(measure_f, shots, results1, results0) From 0bd74cbd6a646f05b375b5a8dc3924135dbefc3b Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 3 Jun 2024 16:30:33 -0400 Subject: [PATCH 13/26] Update pennylane/transforms/dynamic_one_shot.py Co-authored-by: Mudit Pandey --- pennylane/transforms/dynamic_one_shot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 58b1d784ec4..f927c9a0d2b 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -249,7 +249,7 @@ def measurement_with_no_shots(measurement): interface = qml.math.get_deep_interface(circuit.data) interface = "numpy" if interface == "builtins" else interface - active_jit = qml.compiler.active_compiler() + active_qjit = qml.compiler.active() all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)] n_mcms = len(all_mcms) From 7dd313d75146885eeaad671c0feb976a605514aa Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 3 Jun 2024 16:37:26 -0400 Subject: [PATCH 14/26] Update pennylane/transforms/dynamic_one_shot.py Co-authored-by: Mudit Pandey --- pennylane/transforms/dynamic_one_shot.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index f927c9a0d2b..43b22df9b7f 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -327,13 +327,11 @@ def gather_mcm_jit(circuit_measurement, measurement, is_valid): if not found: raise LookupError("MCM not found") meas = qml.math.squeeze(meas) + sum_valid = qml.math.sum(is_valid) + count_1 = qml.math.sum(meas * is_valid) if isinstance(circuit_measurement, CountsMP): - sum_valid = qml.math.sum(is_valid) - count_1 = qml.math.sum(meas * is_valid) return {0: sum_valid - count_1, 1: count_1} if isinstance(circuit_measurement, ProbabilityMP): - sum_valid = qml.math.sum(is_valid) - count_1 = qml.math.sum(meas * is_valid) counts = qml.math.array( [sum_valid - count_1, count_1], like=qml.math.get_deep_interface(is_valid) ) From 7fd6447ef413d9f9ca040c34860371112ab90ad8 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 3 Jun 2024 16:39:03 -0400 Subject: [PATCH 15/26] Rename --- pennylane/transforms/dynamic_one_shot.py | 62 ++++++++++++------------ 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 43b22df9b7f..8ace700969d 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -213,10 +213,10 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): new_operations.append(new_op) else: new_operations.append(op) - if isinstance(op, MidMeasureMP): - new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) - elif "MidCircuitMeasure" in str(type(op)): + if "MidCircuitMeasure" in str(type(op)): new_measurements.append(qml.sample(op.out_classical_tracers[0])) + elif isinstance(op, MidMeasureMP): + new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) return qml.tape.QuantumScript( new_operations, new_measurements, @@ -276,7 +276,7 @@ def measurement_with_no_shots(measurement): ) if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) - elif m.mv and active_jit: + elif m.mv and active_qjit: meas = gather_mcm_jit(m, mcm_samples, is_valid) elif m.mv: meas = gather_mcm(m, mcm_samples, is_valid) @@ -290,7 +290,7 @@ def measurement_with_no_shots(measurement): # as it assumes all elements of the input are of builtin python types and not belonging # to any particular interface result = qml.math.array(result, like=interface) - if active_jit: + if active_qjit: if isinstance(m, CountsMP): normalized_meas.append( (result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0)) @@ -307,52 +307,52 @@ def measurement_with_no_shots(measurement): return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] -def gather_mcm_jit(circuit_measurement, measurement, is_valid): - """Combines, gathers and normalizes several measurements with trivial measurement values - when the Catalyst compiler is active. +def gather_mcm_jit(measurement, samples, is_valid): + """Process MCM measurements when the Catalyst compiler is active. Args: - circuit_measurement (MeasurementProcess): measurement - measurement (TensorLike): measurement results - samples (List[dict]): Mid-circuit measurement samples + measurement (MeasurementProcess): measurement + samples (dict): Mid-circuit measurement samples + is_valid (TensorLike): Mask of valid samples Returns: TensorLike: The combined measurement outcome """ found, meas = False, None - for k, meas in measurement.items(): - if circuit_measurement.mv is k.out_classical_tracers[0]: + for k, meas in samples.items(): + if measurement.mv is k.out_classical_tracers[0]: found = True break if not found: raise LookupError("MCM not found") meas = qml.math.squeeze(meas) - sum_valid = qml.math.sum(is_valid) - count_1 = qml.math.sum(meas * is_valid) - if isinstance(circuit_measurement, CountsMP): + if isinstance(measurement, (CountsMP, ProbabilityMP)): + sum_valid = qml.math.sum(is_valid) + count_1 = qml.math.sum(meas * is_valid) + if isinstance(measurement, CountsMP): return {0: sum_valid - count_1, 1: count_1} - if isinstance(circuit_measurement, ProbabilityMP): + if isinstance(measurement, ProbabilityMP): counts = qml.math.array( [sum_valid - count_1, count_1], like=qml.math.get_deep_interface(is_valid) ) return counts / sum_valid - return gather_non_mcm(circuit_measurement, meas, is_valid) + return gather_non_mcm(measurement, meas, is_valid) -def gather_non_mcm(circuit_measurement, measurement, is_valid): - """Combines, gathers and normalizes several measurements with trivial measurement values. +def gather_non_mcm(circuit_measurement, measurements, is_valid): + """Combines, gathers and normalizes an array of terminal measurements. Args: - circuit_measurement (MeasurementProcess): measurement - measurement (TensorLike): measurement results - samples (List[dict]): Mid-circuit measurement samples + circuit_measurement (MeasurementProcess): Measurement + measurements (TensorLike): Stacked measurement results + is_valid (TensorLike): Mask of valid samples Returns: TensorLike: The combined measurement outcome """ if isinstance(circuit_measurement, CountsMP): tmp = Counter() - for i, d in enumerate(measurement): + for i, d in enumerate(measurements): tmp.update( dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items()) ) @@ -360,23 +360,23 @@ def gather_non_mcm(circuit_measurement, measurement, is_valid): tmp = Counter({k: v for k, v in tmp.items() if v > 0}) return dict(sorted(tmp.items())) if isinstance(circuit_measurement, ExpectationMP): - return qml.math.sum(measurement * is_valid) / qml.math.sum(is_valid) + return qml.math.sum(measurements * is_valid) / qml.math.sum(is_valid) if isinstance(circuit_measurement, ProbabilityMP): - return qml.math.sum(measurement * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum( + return qml.math.sum(measurements * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum( is_valid ) if isinstance(circuit_measurement, SampleMP): is_interface_jax = qml.math.get_deep_interface(is_valid) == "jax" - if is_interface_jax and measurement.ndim == 2: + if is_interface_jax and measurements.ndim == 2: is_valid = is_valid.reshape((-1, 1)) return ( - qml.math.where(is_valid, measurement, fill_in_value) + qml.math.where(is_valid, measurements, fill_in_value) if is_interface_jax - else measurement[is_valid] + else measurements[is_valid] ) # VarianceMP - expval = qml.math.sum(measurement * is_valid) / qml.math.sum(is_valid) - return qml.math.sum((measurement - expval) ** 2 * is_valid) / qml.math.sum(is_valid) + expval = qml.math.sum(measurements * is_valid) / qml.math.sum(is_valid) + return qml.math.sum((measurements - expval) ** 2 * is_valid) / qml.math.sum(is_valid) def gather_mcm(measurement, samples, is_valid): From b5c1415b5be173c7c8f5914fd6ba426de6807477 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 3 Jun 2024 20:41:03 +0000 Subject: [PATCH 16/26] Fix docstring. --- pennylane/transforms/dynamic_one_shot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 8ace700969d..95549e33555 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -385,6 +385,7 @@ def gather_mcm(measurement, samples, is_valid): Args: measurement (MeasurementProcess): measurement samples (List[dict]): Mid-circuit measurement samples + is_valid (TensorLike): Mask of valid samples Returns: TensorLike: The combined measurement outcome From ef4fc27981f4cf39fb822707029812b31c03aac8 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 3 Jun 2024 20:43:04 +0000 Subject: [PATCH 17/26] Rename gather_mcm_jit --- pennylane/transforms/dynamic_one_shot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 95549e33555..0f966c3a0d3 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -277,7 +277,7 @@ def measurement_with_no_shots(measurement): if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) elif m.mv and active_qjit: - meas = gather_mcm_jit(m, mcm_samples, is_valid) + meas = gather_mcm_qjit(m, mcm_samples, is_valid) elif m.mv: meas = gather_mcm(m, mcm_samples, is_valid) elif interface != "jax" and not has_valid: @@ -307,7 +307,7 @@ def measurement_with_no_shots(measurement): return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] -def gather_mcm_jit(measurement, samples, is_valid): +def gather_mcm_qjit(measurement, samples, is_valid): """Process MCM measurements when the Catalyst compiler is active. Args: From 646a626382bcde342b188e178a8b2f93b5b85185 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 3 Jun 2024 20:55:29 +0000 Subject: [PATCH 18/26] utils => mcm_utils. --- .../test_default_qubit_native_mcm.py | 30 +++++++++---------- tests/helpers/{utils.py => mcm_utils.py} | 0 tests/test_compiler.py | 4 +-- 3 files changed, 17 insertions(+), 17 deletions(-) rename tests/helpers/{utils.py => mcm_utils.py} (100%) diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index b3674d87414..a723f01407e 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -14,9 +14,9 @@ """Tests for default qubit preprocessing.""" from typing import Sequence +import mcm_utils import numpy as np import pytest -import utils import pennylane as qml from pennylane.devices.qubit.apply_operation import MidMeasureMP, apply_mid_measure @@ -123,7 +123,7 @@ def func(x, y): results1 = func1(*params) results2 = func2(*params) - utils.validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) # pylint: disable=unused-argument @@ -171,7 +171,7 @@ def func(x, y, z): results1 = func1(*params) results2 = func2(*params) - utils.validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("postselect", [None, 0, 1]) @@ -195,7 +195,7 @@ def func(x, y, z): results2 = func2(*params) for measure_f, res1, res2 in zip([qml.counts, qml.expval], results1, results2): - utils.validate_measurements(measure_f, 5000, res1, res2) + mcm_utils.validate_measurements(measure_f, 5000, res1, res2) @pytest.mark.parametrize("shots", [None, 3000, [3000, 3001]]) @@ -226,7 +226,7 @@ def func(x, y): results1 = func1(*params) results2 = func2(*params) - utils.validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("postselect", [None, 0, 1]) @@ -255,11 +255,11 @@ def func(x, y, z): if isinstance(shots, Sequence): for s, r1, r2 in zip(shots, results1, results2): for _r1, _r2 in zip(r1, r2): - utils.validate_measurements(measure_f, s, _r1, _r2) + mcm_utils.validate_measurements(measure_f, s, _r1, _r2) return for r1, r2 in zip(results1, results2): - utils.validate_measurements(measure_f, shots, r1, r2) + mcm_utils.validate_measurements(measure_f, shots, r1, r2) @pytest.mark.parametrize( @@ -299,7 +299,7 @@ def func(x): results1 = func1(param) results2 = func2(param) - utils.validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("shots", [None, 5000, [5000, 5001]]) @@ -341,7 +341,7 @@ def func(x): results1 = func1(param) results2 = func2(param) - utils.validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("shots", [None, 5000, [5000, 5001]]) @@ -372,7 +372,7 @@ def func(x, y, z): results1 = func1(*params) results2 = func2(*params) - utils.validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("shots", [7500, [5000, 5001]]) @@ -403,7 +403,7 @@ def func(x): results1 = func1(param) results2 = func2(param) - utils.validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) @pytest.mark.parametrize("shots", [5000]) @@ -435,7 +435,7 @@ def func(x, y): results1 = func1(*param) results2 = func2(*param) - utils.validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) grad1 = qml.grad(func)(*param) grad2 = qml.grad(func2)(*param) @@ -467,7 +467,7 @@ def func(x, y): results1 = func1(*param) results2 = func2(*param) - utils.validate_measurements(measure_fn, shots, results1, results2, batch_size=2) + mcm_utils.validate_measurements(measure_fn, shots, results1, results2, batch_size=2) if measure_fn is qml.sample and postselect is None: for i in range(2): # batch_size @@ -524,7 +524,7 @@ def func(x, y): results1 = func1(*param) results2 = func2(*param) - utils.validate_measurements(qml.sample, shots, results1, results2, batch_size=None) + mcm_utils.validate_measurements(qml.sample, shots, results1, results2, batch_size=None) evals = obs.eigvals() for eig in evals: @@ -688,4 +688,4 @@ def func(x): results1 = func1(param) results2 = func2(param) - utils.validate_measurements(measure_f, shots, results1, results2) + mcm_utils.validate_measurements(measure_f, shots, results1, results2) diff --git a/tests/helpers/utils.py b/tests/helpers/mcm_utils.py similarity index 100% rename from tests/helpers/utils.py rename to tests/helpers/mcm_utils.py diff --git a/tests/test_compiler.py b/tests/test_compiler.py index affd1962114..b9c45b80b7f 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -17,9 +17,9 @@ # pylint: disable=import-outside-toplevel from unittest.mock import patch +import mcm_utils import numpy as np import pytest -import utils import pennylane as qml from pennylane import numpy as np @@ -810,4 +810,4 @@ def ansatz(): if measure_f == qml.sample: results0 = results0[results0 != fill_in_value] results1 = results1[results1 != fill_in_value] - utils.validate_measurements(measure_f, shots, results1, results0) + mcm_utils.validate_measurements(measure_f, shots, results1, results0) From 152ce0744b9a250d319f7d06bc56f4ac4bf58bab Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Tue, 4 Jun 2024 19:58:14 +0000 Subject: [PATCH 19/26] Indent block. --- pennylane/transforms/dynamic_one_shot.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 0f966c3a0d3..2934baeb628 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -329,13 +329,13 @@ def gather_mcm_qjit(measurement, samples, is_valid): if isinstance(measurement, (CountsMP, ProbabilityMP)): sum_valid = qml.math.sum(is_valid) count_1 = qml.math.sum(meas * is_valid) - if isinstance(measurement, CountsMP): - return {0: sum_valid - count_1, 1: count_1} - if isinstance(measurement, ProbabilityMP): - counts = qml.math.array( - [sum_valid - count_1, count_1], like=qml.math.get_deep_interface(is_valid) - ) - return counts / sum_valid + if isinstance(measurement, CountsMP): + return {0: sum_valid - count_1, 1: count_1} + if isinstance(measurement, ProbabilityMP): + counts = qml.math.array( + [sum_valid - count_1, count_1], like=qml.math.get_deep_interface(is_valid) + ) + return counts / sum_valid return gather_non_mcm(measurement, meas, is_valid) From 6434d3124516f1bf506df6bfaa0eaf9684397032 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Tue, 4 Jun 2024 21:27:28 +0000 Subject: [PATCH 20/26] Fix MCMConfig default. --- pennylane/devices/execution_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 083f2880b6f..e8dd9816de1 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -14,7 +14,7 @@ """ Contains the :class:`ExecutionConfig` data class. """ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional, Union from pennylane.workflow import SUPPORTED_INTERFACES @@ -95,7 +95,7 @@ class ExecutionConfig: derivative_order: int = 1 """The derivative order to compute while evaluating a gradient""" - mcm_config: Union[MCMConfig, dict] = MCMConfig() + mcm_config: Union[MCMConfig, dict] = field(default_factory=MCMConfig) """Configuration options for handling mid-circuit measurements""" def __post_init__(self): From cba316abac191e3c86e0f8a541d9839526067872 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Wed, 5 Jun 2024 12:06:15 -0400 Subject: [PATCH 21/26] # pragma: no cover --- pennylane/transforms/dynamic_one_shot.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 2203203775b..226c1e0190f 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -19,7 +19,7 @@ # pylint: disable=import-outside-toplevel from collections import Counter -from typing import Callable, Sequence +from typing import Callable, Sequence, Tuple import numpy as np @@ -50,7 +50,7 @@ def null_postprocessing(results): @transform def dynamic_one_shot( tape: qml.tape.QuantumTape, **kwargs -) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: +) -> Tuple[Sequence[qml.tape.QuantumTape], Callable]: """Transform a QNode to into several one-shot tapes to support dynamic circuit execution. Args: @@ -208,13 +208,13 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): new_measurements.append(m) new_operations = [] for op in circuit.operations: - if "MidCircuitMeasure" in str(type(op)): + if "MidCircuitMeasure" in str(type(op)): # pragma: no cover new_op = op new_op.bypass_postselect = True new_operations.append(new_op) else: new_operations.append(op) - if "MidCircuitMeasure" in str(type(op)): + if "MidCircuitMeasure" in str(type(op)): # pragma: no cover new_measurements.append(qml.sample(op.out_classical_tracers[0])) elif isinstance(op, MidMeasureMP): new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) @@ -278,7 +278,7 @@ def measurement_with_no_shots(measurement): if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) elif m.mv and active_qjit: - meas = gather_mcm_qjit(m, mcm_samples, is_valid) + meas = gather_mcm_qjit(m, mcm_samples, is_valid) # pragma: no cover elif m.mv: meas = gather_mcm(m, mcm_samples, is_valid) elif interface != "jax" and not has_valid: @@ -291,7 +291,7 @@ def measurement_with_no_shots(measurement): # as it assumes all elements of the input are of builtin python types and not belonging # to any particular interface result = qml.math.array(result, like=interface) - if active_qjit: + if active_qjit: # pragma: no cover if isinstance(m, CountsMP): normalized_meas.append( (result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0)) @@ -308,7 +308,7 @@ def measurement_with_no_shots(measurement): return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] -def gather_mcm_qjit(measurement, samples, is_valid): +def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover """Process MCM measurements when the Catalyst compiler is active. Args: From 542f0076996738d2c77c8a4612105be974c61aea Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Wed, 5 Jun 2024 14:29:12 -0400 Subject: [PATCH 22/26] Update pennylane/transforms/dynamic_one_shot.py Co-authored-by: Mudit Pandey --- pennylane/transforms/dynamic_one_shot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 226c1e0190f..4dcfb5fd36f 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -50,7 +50,7 @@ def null_postprocessing(results): @transform def dynamic_one_shot( tape: qml.tape.QuantumTape, **kwargs -) -> Tuple[Sequence[qml.tape.QuantumTape], Callable]: +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Transform a QNode to into several one-shot tapes to support dynamic circuit execution. Args: From c2540fe6bcce5883132d25ecf5f3d0e07cc4ba81 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Wed, 5 Jun 2024 14:29:28 -0400 Subject: [PATCH 23/26] Update pennylane/transforms/dynamic_one_shot.py Co-authored-by: Mudit Pandey --- pennylane/transforms/dynamic_one_shot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 4dcfb5fd36f..a554e24a75d 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -19,7 +19,7 @@ # pylint: disable=import-outside-toplevel from collections import Counter -from typing import Callable, Sequence, Tuple +from typing import Callable, Sequence import numpy as np From 6ecb359fbecaf1b60283fb2de57ba1c259d5beaa Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Thu, 6 Jun 2024 14:59:35 +0000 Subject: [PATCH 24/26] Export is_mcm in transforms for use in tape module. --- pennylane/tape/tape.py | 10 ++-------- pennylane/transforms/__init__.py | 2 +- pennylane/transforms/dynamic_one_shot.py | 12 ++++++------ 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/pennylane/tape/tape.py b/pennylane/tape/tape.py index 9e7f3ff9cb5..450570ab2ca 100644 --- a/pennylane/tape/tape.py +++ b/pennylane/tape/tape.py @@ -19,13 +19,7 @@ from threading import RLock import pennylane as qml -from pennylane.measurements import ( - CountsMP, - MeasurementProcess, - MidMeasureMP, - ProbabilityMP, - SampleMP, -) +from pennylane.measurements import CountsMP, MeasurementProcess, ProbabilityMP, SampleMP from pennylane.operation import DecompositionUndefinedError, Operator, StatePrepBase from pennylane.pytrees import register_pytree from pennylane.queuing import AnnotatedQueue, QueuingManager, process_queue @@ -51,7 +45,7 @@ def _validate_computational_basis_sampling(tape): qubit-wise commutativity relation.""" measurements = tape.measurements n_meas = len(measurements) - n_mcms = sum(isinstance(op, MidMeasureMP) for op in tape.operations) + n_mcms = sum(qml.transforms.is_mcm(op) for op in tape.operations) non_comp_basis_sampling_obs = [] comp_basis_sampling_obs = [] comp_basis_indices = [] diff --git a/pennylane/transforms/__init__.py b/pennylane/transforms/__init__.py index 7df898f2ee5..f69d978040b 100644 --- a/pennylane/transforms/__init__.py +++ b/pennylane/transforms/__init__.py @@ -288,7 +288,7 @@ def circuit(x, y): from .decompositions import clifford_t_decomposition from .defer_measurements import defer_measurements -from .dynamic_one_shot import dynamic_one_shot +from .dynamic_one_shot import dynamic_one_shot, is_mcm from .sign_expand import sign_expand from .hamiltonian_expand import hamiltonian_expand, sum_expand from .split_non_commuting import split_non_commuting diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index a554e24a75d..88d5e5287e2 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -40,6 +40,12 @@ fill_in_value = np.iinfo(np.int32).min +def is_mcm(operation): + """Returns True if the operation is a mid-circuit measurement and False otherwise.""" + mcm = isinstance(operation, MidMeasureMP) + return mcm or "MidCircuitMeasure" in str(type(operation)) + + def null_postprocessing(results): """A postprocessing function returned by a transform that only converts the batch of results into a result for a single ``QuantumTape``. @@ -181,12 +187,6 @@ def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs): return self.default_qnode_transform(qnode, targs, tkwargs) -def is_mcm(operation): - """Returns True if the operation is a mid-circuit measurement and False otherwise.""" - mcm = isinstance(operation, MidMeasureMP) - return mcm or "MidCircuitMeasure" in str(type(operation)) - - def init_auxiliary_tape(circuit: qml.tape.QuantumScript): """Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations. From d27a24402527919496d8bd5db3d9859b20c7c5f3 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Thu, 6 Jun 2024 15:41:21 +0000 Subject: [PATCH 25/26] Fix _validate_computational_basis_sampling? --- pennylane/tape/tape.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pennylane/tape/tape.py b/pennylane/tape/tape.py index 450570ab2ca..59f21dae5dd 100644 --- a/pennylane/tape/tape.py +++ b/pennylane/tape/tape.py @@ -62,10 +62,10 @@ def _validate_computational_basis_sampling(tape): for idx, (cb_obs, global_idx) in enumerate( zip(comp_basis_sampling_obs, comp_basis_indices) ): - if cb_obs.wires == empty_wires: - all_wires = qml.wires.Wires.all_wires([m.wires for m in measurements]) - break if global_idx < n_meas - n_mcms: + if cb_obs.wires == empty_wires: + all_wires = qml.wires.Wires.all_wires([m.wires for m in measurements]) + break all_wires.append(cb_obs.wires) if idx == len(comp_basis_sampling_obs) - 1: all_wires = qml.wires.Wires.all_wires(all_wires) From 8095c6fa3b1dc6c01decdfa4460a8bb76f1440e7 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Fri, 7 Jun 2024 14:01:22 +0000 Subject: [PATCH 26/26] Remove new_operations logic --- pennylane/transforms/dynamic_one_shot.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 88d5e5287e2..b05c1d583b0 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -206,20 +206,13 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): new_measurements.append(SampleMP(obs=m.obs)) else: new_measurements.append(m) - new_operations = [] for op in circuit.operations: - if "MidCircuitMeasure" in str(type(op)): # pragma: no cover - new_op = op - new_op.bypass_postselect = True - new_operations.append(new_op) - else: - new_operations.append(op) if "MidCircuitMeasure" in str(type(op)): # pragma: no cover new_measurements.append(qml.sample(op.out_classical_tracers[0])) elif isinstance(op, MidMeasureMP): new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) return qml.tape.QuantumScript( - new_operations, + circuit.operations, new_measurements, shots=[1] * circuit.shots.total_shots, trainable_params=circuit.trainable_params, @@ -292,6 +285,9 @@ def measurement_with_no_shots(measurement): # to any particular interface result = qml.math.array(result, like=interface) if active_qjit: # pragma: no cover + # `result` contains (bases, counts) need to return (basis, sum(counts)) where `is_valid` + # Any row of `result[0]` contains basis, so we return `result[0][0]` + # We return the sum of counts (`result[1]`) weighting by `is_valid`, which is `0` for invalid samples if isinstance(m, CountsMP): normalized_meas.append( (result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0)) @@ -329,14 +325,15 @@ def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover raise LookupError("MCM not found") meas = qml.math.squeeze(meas) if isinstance(measurement, (CountsMP, ProbabilityMP)): + interface = qml.math.get_deep_interface(is_valid) sum_valid = qml.math.sum(is_valid) count_1 = qml.math.sum(meas * is_valid) if isinstance(measurement, CountsMP): - return {0: sum_valid - count_1, 1: count_1} - if isinstance(measurement, ProbabilityMP): - counts = qml.math.array( - [sum_valid - count_1, count_1], like=qml.math.get_deep_interface(is_valid) + return qml.math.array([0, 1], like=interface), qml.math.array( + [sum_valid - count_1, count_1], like=interface ) + if isinstance(measurement, ProbabilityMP): + counts = qml.math.array([sum_valid - count_1, count_1], like=interface) return counts / sum_valid return gather_non_mcm(measurement, meas, is_valid)