From 6d1fd420b4af0ea9653ac80d32c166f650247c4c Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 23 Apr 2024 17:03:16 -0400 Subject: [PATCH] Fix `qml.counts` overflow error for more than 64 qubits (#5556) **Context:** `qml.counts` was converting samples to integers and using f-strings to format them as binary strings; however, this caused integer overflow when measuring more than 64 wires. Without the changes in this PR: ```python import pennylane as qml import numpy as np np.random.seed(582) num_wires = 70 wires = list(range(num_wires)) dev = qml.device("default.clifford", wires=wires, shots=100) @qml.qnode(dev) # @qml.defer_measurements def circuit(): [qml.PauliX(w) for w in wires] return qml.counts() circuit() ``` ```pycon {'-000000000000000000000000000000000000000000000000000000000000000000001': tensor(100, requires_grad=True)} ``` With the changes: ```python import pennylane as qml import numpy as np np.random.seed(582) num_wires = 70 wires = list(range(num_wires)) dev = qml.device("default.clifford", wires=wires, shots=100) @qml.qnode(dev) # @qml.defer_measurements def circuit(): [qml.PauliX(w) for w in wires] return qml.counts() circuit() ``` ```pycon {'1111111111111111111111111111111111111111111111111111111111111111111111': tensor(100, requires_grad=True)} ``` **Description of the Change:** * Change `qml.counts` to convert directly from a list of samples to a binary string without first converting to integers. * Change `QubitDevice._samples_to_counts` to cast samples to `int8` instead of `int64`. This cast only sets the dtype of the individual samples (they are not combined into integers), and `int8` is more memory efficient. **Benefits:** * `qml.counts` now works with an arbitrary number of wires. This is useful for devices like `default.clifford`, which can simulate hundreds of qubits. 
**Possible Drawbacks:** **Related GitHub Issues:** #5537 , #5513 --------- Co-authored-by: David Wierichs --- doc/releases/changelog-dev.md | 1 + pennylane/_qubit_device.py | 2 +- pennylane/measurements/counts.py | 26 ++++++++++++++---------- pennylane/transforms/dynamic_one_shot.py | 2 +- tests/measurements/test_counts.py | 24 +++++++++++++++------- 5 files changed, 35 insertions(+), 20 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 563e74dd02a..1289ec0865f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -470,6 +470,7 @@ * `qml.counts` no longer returns negative samples when measuring 8 or more wires. [(#5544)](https://github.com/PennyLaneAI/pennylane/pull/5544) + [(#5556)](https://github.com/PennyLaneAI/pennylane/pull/5556) * The `dynamic_one_shot` transform now works with broadcasting. [(#5473)](https://github.com/PennyLaneAI/pennylane/pull/5473) diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py index ba8b031219a..dbb01c0e0a7 100644 --- a/pennylane/_qubit_device.py +++ b/pennylane/_qubit_device.py @@ -1462,7 +1462,7 @@ def circuit(x): if mp.obs is None and not isinstance(mp.mv, MeasurementValue): # convert samples and outcomes (if using) from arrays to str for dict keys samples = np.array([sample for sample in samples if not np.any(np.isnan(sample))]) - samples = qml.math.cast_like(samples, qml.math.int64(0)) + samples = qml.math.cast_like(samples, qml.math.int8(0)) samples = np.apply_along_axis(_sample_to_str, -1, samples) batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires)) if mp.all_outcomes: diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py index f8539665187..7a99ae84cd4 100644 --- a/pennylane/measurements/counts.py +++ b/pennylane/measurements/counts.py @@ -299,6 +299,7 @@ def circuit(x): if self.obs is None and not isinstance(self.mv, MeasurementValue): # convert 
samples and outcomes (if using) from arrays to str for dict keys + batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires)) # remove nans mask = qml.math.isnan(samples) @@ -307,21 +308,24 @@ def circuit(x): mask = np.logical_not(np.any(mask, axis=tuple(range(1, samples.ndim)))) samples = samples[mask, ...] - # convert to string - def convert(x): - return f"{x:0{num_wires}b}" + def convert(sample): + # convert array of ints to string + return "".join(str(s) for s in sample) - exp2 = 2 ** np.arange(num_wires - 1, -1, -1) - samples = np.einsum("...i,i", samples, exp2) - new_shape = samples.shape - samples = qml.math.cast_like(samples, qml.math.int64(0)) - samples = list(map(convert, samples.ravel())) - samples = np.array(samples).reshape(new_shape) + new_shape = samples.shape[:-1] + # Flatten broadcasting axis + flattened_samples = np.reshape(samples, (-1, shape[-1])).astype(np.int8) + samples = list(map(convert, flattened_samples)) + samples = np.reshape(np.array(samples), new_shape) - batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires)) if self.all_outcomes: + + def convert_from_int(x): + # convert int to binary string + return f"{x:0{num_wires}b}" + num_wires = len(self.wires) if len(self.wires) > 0 else shape[-1] - outcomes = list(map(convert, range(2**num_wires))) + outcomes = list(map(convert_from_int, range(2**num_wires))) elif self.all_outcomes: # This also covers statistics for mid-circuit measurements manipulated using diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 98625aa6042..2480bcff9aa 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -381,5 +381,5 @@ def gather_mcm(measurement, samples): meas_tmp = measurement.__class__(wires=wires) new_measurement = meas_tmp.process_samples(mcm_samples, wire_order=wires) if isinstance(measurement, 
CountsMP) and not use_as_is: - new_measurement = dict(sorted((int(x, 2), y) for x, y in new_measurement.items())) + new_measurement = dict(sorted((int(x), y) for x, y in new_measurement.items())) return new_measurement diff --git a/tests/measurements/test_counts.py b/tests/measurements/test_counts.py index 085bccd3c49..afdba1ed758 100644 --- a/tests/measurements/test_counts.py +++ b/tests/measurements/test_counts.py @@ -145,19 +145,29 @@ def test_counts_with_nan_samples(self): total_counts = sum(count for count in result.values()) assert total_counts == 997 - @pytest.mark.parametrize("n_wires", [5, 8, 10]) - @pytest.mark.parametrize("all_outcomes", [True, False]) - def test_counts_multi_wires_no_overflow(self, n_wires, all_outcomes): + @pytest.mark.parametrize("batch_size", [None, 1, 4]) + @pytest.mark.parametrize("n_wires", [4, 10, 65]) + @pytest.mark.parametrize("all_outcomes", [False, True]) + def test_counts_multi_wires_no_overflow(self, n_wires, all_outcomes, batch_size): """Test that binary strings for wire samples are not negative due to overflow.""" + if all_outcomes and n_wires == 65: + pytest.skip("Too much memory being used, skipping") shots = 1000 - total_wires = 10 - samples = np.random.choice([0, 1], size=(shots, total_wires)).astype(np.float64) + total_wires = 65 + shape = (batch_size, shots, total_wires) if batch_size else (shots, total_wires) + samples = np.random.choice([0, 1], size=shape).astype(np.float64) result = qml.counts(wires=list(range(n_wires)), all_outcomes=all_outcomes).process_samples( samples, wire_order=list(range(total_wires)) ) - assert sum(result.values()) == shots - assert all(0 <= int(sample, 2) <= 2**n_wires for sample in result.keys()) + if batch_size: + assert len(result) == batch_size + for r in result: + assert sum(r.values()) == shots + assert all("-" not in sample for sample in r.keys()) + else: + assert sum(result.values()) == shots + assert all("-" not in sample for sample in result.keys()) def test_counts_obs(self): 
"""Test that the counts function outputs counts of the right size for observables"""