From 59a1e0586e707d057a0c92d4239036afa5312b73 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 24 May 2024 18:54:36 -0400 Subject: [PATCH] Improve support for Torch and Jax with `dynamic_one_shot` (#5672) **Context:** Opened in favour of #5630. Bug fix for #5442. This PR updates `dynamic_one_shot` so that it has better compatibility with the `torch` and `jax` interfaces. **Description of the Change:** * Change casting method from `array.astype()` to `qml.math.cast` in the `apply_operation` dispatch for `MidMeasureMP`. * Update usage of `qml.math` in `dynamic_one_shot`. * When using `qml.counts`, cast results to ints before converting to strings for lists of MCM values and floats for single MCM values. This is needed because jax arrays are not hashable, and the hash of torch tensors seems to be independent of the value(s) stored inside it. Thus, neither can be used as keys for dictionaries. **Benefits:** Better interface support with `dynamic_one_shot`. **Possible Drawbacks:** **Related GitHub Issues:** --------- Co-authored-by: Vincent Michaud-Rioux Co-authored-by: Vincent Michaud-Rioux Co-authored-by: Christina Lee Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Astral Cai Co-authored-by: David Wierichs Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com> Co-authored-by: Pietropaolo Frisoni Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com> Co-authored-by: Jay Soni Co-authored-by: Guillermo Alonso-Linaje <65235481+KetpuntoG@users.noreply.github.com> Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com> Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com> Co-authored-by: Diksha Dhawan <40900030+ddhawan11@users.noreply.github.com> Co-authored-by: Isaac De Vlugt Co-authored-by: Diego <67476785+DSGuala@users.noreply.github.com> Co-authored-by: trbromley Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com> Co-authored-by: David Ittah Co-authored-by: soranjh <40344468+soranjh@users.noreply.github.com> --- doc/releases/changelog-dev.md | 4 + pennylane/devices/qubit/apply_operation.py | 6 +- pennylane/devices/qubit/simulate.py | 2 +- pennylane/math/single_dispatch.py | 1 + pennylane/transforms/dynamic_one_shot.py | 28 ++- .../test_default_qubit_native_mcm.py | 55 +++++- tests/transforms/test_dynamic_one_shot.py | 165 ++++++++++++++++++ 7 files changed, 244 insertions(+), 17 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3690638820b..3b03629e105 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -156,6 +156,9 @@

Bug fixes 🐛

+* The `dynamic_one_shot` transform now has expanded support for the `jax` and `torch` interfaces. + [(#5672)](https://github.com/PennyLaneAI/pennylane/pull/5672) + * The decomposition of `StronglyEntanglingLayers` is now compatible with broadcasting. [(#5716)](https://github.com/PennyLaneAI/pennylane/pull/5716) @@ -213,5 +216,6 @@ Korbinian Kottmann, Christina Lee, Vincent Michaud-Rioux, Lee James O'Riordan, +Mudit Pandey, Kenya Sakka, David Wierichs. diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index 9b42e942c74..ae0c9117c38 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -330,8 +330,10 @@ def binomial_fn(n, p): # to reset enables jax.jit and prevents it from using Python callbacks element = op.reset and sample == 1 matrix = qml.math.array( - [[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]], like=interface - ).astype(float) + [[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]], + like=interface, + dtype=float, + ) state = apply_operation( qml.QubitUnitary(matrix, wire), state, is_state_batched=is_state_batched, debugger=debugger ) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index f1929cd83db..7218449c625 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -287,7 +287,7 @@ def simulate( trainable_params=circuit.trainable_params, ) keys = jax_random_split(prng_key, num=circuit.shots.total_shots) - if qml.math.get_deep_interface(circuit.data) == "jax": + if qml.math.get_deep_interface(circuit.data) == "jax" and prng_key is not None: # pylint: disable=import-outside-toplevel import jax diff --git a/pennylane/math/single_dispatch.py b/pennylane/math/single_dispatch.py index 900d90d98dc..7f69ee82b11 100644 --- a/pennylane/math/single_dispatch.py +++ b/pennylane/math/single_dispatch.py @@ -242,6 +242,7 @@ def _take_autograd(tensor, indices, axis=None): ar.autoray._SUBMODULE_ALIASES["tensorflow", "isclose"] = "tensorflow.experimental.numpy" ar.autoray._SUBMODULE_ALIASES["tensorflow", "atleast_1d"] = "tensorflow.experimental.numpy" ar.autoray._SUBMODULE_ALIASES["tensorflow", "all"] = "tensorflow.experimental.numpy" +ar.autoray._SUBMODULE_ALIASES["tensorflow", "ravel"] = "tensorflow.experimental.numpy" ar.autoray._SUBMODULE_ALIASES["tensorflow", "vstack"] = "tensorflow.experimental.numpy" tf_fft_functions = [ diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 109f43db30b..998cefa93c5 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -228,6 +228,7 @@ def measurement_with_no_shots(measurement): ) interface = qml.math.get_deep_interface(circuit.data) + interface = "numpy" if interface == "builtins" else interface all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)] n_mcms = len(all_mcms) @@ -243,10 +244,13 @@ def measurement_with_no_shots(measurement): mcm_samples = qml.math.array( [[res] if single_measurement else res[-n_mcms::] for res in results], like=interface ) - has_postselect = qml.math.array([op.postselect is not None for op in all_mcms]).reshape((1, -1)) + # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1 + has_postselect = qml.math.array( + [[int(op.postselect is not None) for op in all_mcms]], like=interface + ) postselect = qml.math.array( - [0 if op.postselect is None else op.postselect for op in all_mcms] - ).reshape((1, -1)) + [[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface + ) is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1) has_valid = qml.math.any(is_valid) mid_meas = [op for op in circuit.operations if is_mcm(op)] @@ -268,7 +272,12 @@ def measurement_with_no_shots(measurement): meas = measurement_with_no_shots(m) m_count += 1 else: - result = qml.math.array([res[m_count] for res in results], like=interface) + result = [res[m_count] for res in results] + if not isinstance(m, CountsMP): + # We don't need to cast to arrays when using qml.counts. qml.math.array is not viable + # as it assumes all elements of the input are of builtin python types and not belonging + # to any particular interface + result = qml.math.stack(result, like=interface) meas = gather_non_mcm(m, result, is_valid) m_count += 1 if isinstance(m, SampleMP): @@ -292,7 +301,9 @@ def gather_non_mcm(circuit_measurement, measurement, is_valid): if isinstance(circuit_measurement, CountsMP): tmp = Counter() for i, d in enumerate(measurement): - tmp.update(dict((k, v * is_valid[i]) for k, v in d.items())) + tmp.update( + dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items()) + ) tmp = Counter({k: v for k, v in tmp.items() if v > 0}) return dict(sorted(tmp.items())) if isinstance(circuit_measurement, ExpectationMP): @@ -341,14 +352,13 @@ def gather_mcm(measurement, samples, is_valid): counts = qml.math.array(counts, like=interface) return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): - mcm_samples = [{"".join(str(v) for v in tuple(s)): 1} for s in mcm_samples] + mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples] return gather_non_mcm(measurement, mcm_samples, is_valid) + mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface)) if isinstance(measurement, ProbabilityMP): - mcm_samples = qml.math.array(mv.concretize(samples), like=interface).ravel() counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())] counts = qml.math.array(counts, like=interface) return counts / qml.math.sum(counts) - mcm_samples = qml.math.array([mv.concretize(samples)], like=interface).ravel() if isinstance(measurement, CountsMP): - mcm_samples = [{s: 1} for s in mcm_samples] + mcm_samples = [{float(s): 1} for s in mcm_samples] return gather_non_mcm(measurement, mcm_samples, is_valid) diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index 4848ea4f6e3..e2184874cec 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for default qubit preprocessing.""" -from functools import partial, reduce +from functools import reduce from typing import Iterable, Sequence import numpy as np @@ -24,7 +24,11 @@ pytestmark = pytest.mark.slow -get_device = partial(qml.device, name="default.qubit", seed=8237945) + +def get_device(**kwargs): + kwargs.setdefault("shots", None) + kwargs.setdefault("seed", 8237945) + return qml.device("default.qubit", **kwargs) def validate_counts(shots, results1, results2, batch_size=None): @@ -88,7 +92,7 @@ def validate_samples(shots, results1, results2, batch_size=None): assert results1.ndim == results2.ndim if results2.ndim > 1: assert results1.shape[1] == results2.shape[1] - np.allclose(np.sum(results1), np.sum(results2), atol=20, rtol=0.2) + np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2) def validate_expval(shots, results1, results2, batch_size=None): @@ -611,7 +615,7 @@ def test_sample_with_prng_key(shots, postselect, reset): # pylint: disable=import-outside-toplevel from jax.random import PRNGKey - dev = qml.device("default.qubit", shots=shots, seed=PRNGKey(678)) + dev = get_device(shots=shots, seed=PRNGKey(678)) param = [np.pi / 4, np.pi / 3] obs = qml.PauliZ(0) @ qml.PauliZ(1) @@ -659,7 +663,7 @@ def test_jax_jit(diff_method, postselect, reset): shots = 10 - dev = qml.device("default.qubit", shots=shots, seed=jax.random.PRNGKey(678)) + dev = get_device(shots=shots, seed=jax.random.PRNGKey(678)) params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5] obs = qml.PauliY(0) @@ -750,3 +754,44 @@ def func(x): results2 = func2(param) for r1, r2 in zip(results1.keys(), results2.keys()): assert r1 == r2 + + +@pytest.mark.torch +@pytest.mark.parametrize("postselect", [None, 1]) +@pytest.mark.parametrize("diff_method", [None, "best"]) +@pytest.mark.parametrize("measure_f", [qml.probs, qml.sample, qml.expval, qml.var]) +@pytest.mark.parametrize("meas_obj", [qml.PauliZ(1), [0, 1], "composite_mcm", "mcm_list"]) +def test_torch_integration(postselect, diff_method, measure_f, meas_obj): + """Test that native MCM circuits are executed correctly with Torch""" + if measure_f in (qml.var, qml.expval) and ( + isinstance(meas_obj, list) or meas_obj == "mcm_list" + ): + pytest.skip("Can't use wires/mcm lists with var or expval") + + import torch + + shots = 7000 + dev = get_device(shots=shots, seed=123456789) + param = torch.tensor(np.pi / 3, dtype=torch.float64) + + @qml.qnode(dev, diff_method=diff_method) + def func(x): + qml.RX(x, 0) + m0 = qml.measure(0) + qml.RX(0.5 * x, 1) + m1 = qml.measure(1, postselect=postselect) + qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0) + m2 = qml.measure(0) + + mid_measure = 0.5 * m2 if meas_obj == "composite_mcm" else [m1, m2] + measurement_key = "wires" if isinstance(meas_obj, list) else "op" + measurement_value = mid_measure if isinstance(meas_obj, str) else meas_obj + return measure_f(**{measurement_key: measurement_value}) + + func1 = func + func2 = qml.defer_measurements(func) + + results1 = func1(param) + results2 = func2(param) + + validate_measurements(measure_f, shots, results1, results2) diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 3ca186171fe..ce9d67a429b 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -155,3 +155,168 @@ def test_len_measurements_mcms(measure, aux_measure, n_meas): assert len(aux_tape.measurements) == n_meas + n_mcms assert isinstance(aux_tape.measurements[0], aux_measure) assert all(isinstance(m, SampleMP) for m in aux_tape.measurements[1:]) + + +def assert_results(res, shots, n_mcms): + """Helper to check that expected raw results of executing the transformed tape are correct""" + assert len(res) == shots + # One for the non-MeasurementValue MP, and the rest of the mid-circuit measurements + assert all(len(r) == n_mcms + 1 for r in res) + # Not validating distribution of results as device sampling unit tests already validate + # that samples are generated correctly. + + +@pytest.mark.jax +@pytest.mark.parametrize("measure_f", (qml.expval, qml.probs, qml.sample, qml.var)) +@pytest.mark.parametrize("shots", [20, [20, 21]]) +@pytest.mark.parametrize("n_mcms", [1, 3]) +def test_tape_results_jax(shots, n_mcms, measure_f): + """Test that the simulation results of a tape are correct with jax parameters""" + import jax + + dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123)) + param = jax.numpy.array(np.pi / 2) + + mv = qml.measure(0) + mp = mv.measurements[0] + + tape = qml.tape.QuantumScript( + [qml.RX(param, 0), mp] + [MidMeasureMP(0, id=str(i)) for i in range(n_mcms - 1)], + [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], + shots=shots, + ) + + tapes, _ = qml.dynamic_one_shot(tape) + results = dev.execute(tapes)[0] + + # The transformed tape never has a shot vector + if isinstance(shots, list): + shots = sum(shots) + + assert_results(results, shots, n_mcms) + + +@pytest.mark.jax +@pytest.mark.parametrize( + "measure_f, expected1, expected2", + [ + (qml.expval, 1.0, 1.0), + (qml.probs, [1, 0], [0, 1]), + (qml.sample, 1, 1), + (qml.var, 0.0, 0.0), + ], +) +@pytest.mark.parametrize("shots", [20, [20, 21]]) +@pytest.mark.parametrize("n_mcms", [1, 3]) +def test_jax_results_processing(shots, n_mcms, measure_f, expected1, expected2): + """Test that the results of tapes are processed correctly for tapes with jax parameters""" + import jax.numpy as jnp + + mv = qml.measure(0) + mp = mv.measurements[0] + + tape = qml.tape.QuantumScript( + [qml.RX(1.5, 0), mp] + [MidMeasureMP(0)] * (n_mcms - 1), + [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], + shots=shots, + ) + _, fn = qml.dynamic_one_shot(tape) + all_shots = sum(shots) if isinstance(shots, list) else shots + + first_res = jnp.array([1.0, 0.0]) if measure_f == qml.probs else jnp.array(1.0) + rest = jnp.array(1, dtype=int) + single_shot_res = (first_res,) + (rest,) * n_mcms + # Raw results for each shot are (sample_for_first_measurement,) + (sample for 1st MCM, sample for 2nd MCM, ...) + raw_results = (single_shot_res,) * all_shots + raw_results = (raw_results,) + res = fn(raw_results) + + if measure_f is qml.sample: + # All samples 1 + expected1 = ( + [[expected1] * s for s in shots] if isinstance(shots, list) else [expected1] * shots + ) + expected2 = ( + [[expected2] * s for s in shots] if isinstance(shots, list) else [expected2] * shots + ) + else: + expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1 + expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2 + + if isinstance(shots, list): + assert len(res) == len(shots) + for r, e1, e2 in zip(res, expected1, expected2): + # Expected result is 2-list since we have two measurements in the tape + assert qml.math.allclose(r, [e1, e2]) + else: + # Expected result is 2-list since we have two measurements in the tape + assert qml.math.allclose(res, [expected1, expected2]) + + +@pytest.mark.jax +@pytest.mark.parametrize( + "measure_f, expected1, expected2", + [ + (qml.expval, 1.0, 1.0), + (qml.probs, [1, 0], [0, 1]), + (qml.sample, 1, 1), + (qml.var, 0.0, 0.0), + ], +) +@pytest.mark.parametrize("shots", [20, [20, 22]]) +def test_jax_results_postselection_processing(shots, measure_f, expected1, expected2): + """Test that the results of tapes are processed correctly for tapes with jax parameters + when postselecting""" + import jax.numpy as jnp + + param = jnp.array(np.pi / 2) + fill_value = np.iinfo(np.int32).min + mv = qml.measure(0, postselect=1) + mp = mv.measurements[0] + + tape = qml.tape.QuantumScript( + [qml.RX(param, 0), mp, MidMeasureMP(0)], + [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], + shots=shots, + ) + _, fn = qml.dynamic_one_shot(tape) + all_shots = sum(shots) if isinstance(shots, list) else shots + + # Alternating tuple. Only the values at odd indices are valid + first_res_two_shot = ( + (jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])) + if measure_f == qml.probs + else (jnp.array(1.0), jnp.array(0.0)) + ) + first_res = first_res_two_shot * (all_shots // 2) + # Tuple of alternating 1s and 0s. Zero is invalid as postselecting on 1 + postselect_res = (jnp.array(1, dtype=int), jnp.array(0, dtype=int)) * (all_shots // 2) + rest = (jnp.array(1, dtype=int),) * all_shots + # Raw results for each shot are (sample_for_first_measurement, sample for 1st MCM, sample for 2nd MCM) + raw_results = tuple(zip(first_res, postselect_res, rest)) + raw_results = (raw_results,) + res = fn(raw_results) + + if measure_f is qml.sample: + expected1 = ( + [[expected1, fill_value] * (s // 2) for s in shots] + if isinstance(shots, list) + else [expected1, fill_value] * (shots // 2) + ) + expected2 = ( + [[expected2, fill_value] * (s // 2) for s in shots] + if isinstance(shots, list) + else [expected2, fill_value] * (shots // 2) + ) + else: + expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1 + expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2 + + if isinstance(shots, list): + assert len(res) == len(shots) + for r, e1, e2 in zip(res, expected1, expected2): + # Expected result is 2-list since we have two measurements in the tape + assert qml.math.allclose(r, [e1, e2]) + else: + # Expected result is 2-list since we have two measurements in the tape + assert qml.math.allclose(res, [expected1, expected2])