Skip to content

Commit

Permalink
Fix qml.counts overflow error for more than 64 qubits (#5556)
Browse files Browse the repository at this point in the history
**Context:**
`qml.counts` was converting samples to integers and using f-strings to
format as binary strings, however, it caused wires > 64 to have
overflow.

Without the changes in this PR:
```python
import pennylane as qml
import numpy as np

np.random.seed(582)

num_wires = 70
wires = list(range(num_wires))

dev = qml.device("default.clifford", wires=wires, shots=100)

@qml.qnode(dev)
# @qml.defer_measurements
def circuit():
    [qml.PauliX(w) for w in wires]
    return qml.counts()

circuit()
```
```pycon
{'-000000000000000000000000000000000000000000000000000000000000000000001': tensor(100, requires_grad=True)}
```

With the changes:
```python
import pennylane as qml
import numpy as np

np.random.seed(582)

num_wires = 70
wires = list(range(num_wires))

dev = qml.device("default.clifford", wires=wires, shots=100)

@qml.qnode(dev)
# @qml.defer_measurements
def circuit():
    [qml.PauliX(w) for w in wires]
    return qml.counts()

circuit()
```
```pycon
{'1111111111111111111111111111111111111111111111111111111111111111111111': tensor(100, requires_grad=True)}
```

**Description of the Change:**
* Change `qml.counts` to convert directly from list of samples to binary
string without first converting to integers.
* Changed `QubitDevice._samples_to_counts` to cast to `int8` instead of
`int64`. This is only for casting samples, not converting to ints, as
`int8` is more memory efficient.

**Benefits:**
* `qml.counts` now works with arbitrary wires. This is useful for
devices like `default.clifford`, which can simulate hundreds of qubits.

**Possible Drawbacks:**

**Related GitHub Issues:**
#5537 , #5513

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
  • Loading branch information
mudit2812 and dwierichs committed Apr 23, 2024
1 parent 3ae019d commit 6d1fd42
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 20 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@

* `qml.counts` no longer returns negative samples when measuring 8 or more wires.
[(#5544)](https://github.com/PennyLaneAI/pennylane/pull/5544)
[(#5556)](https://github.com/PennyLaneAI/pennylane/pull/5556)

* The `dynamic_one_shot` transform now works with broadcasting.
[(#5473)](https://github.com/PennyLaneAI/pennylane/pull/5473)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,7 @@ def circuit(x):
if mp.obs is None and not isinstance(mp.mv, MeasurementValue):
# convert samples and outcomes (if using) from arrays to str for dict keys
samples = np.array([sample for sample in samples if not np.any(np.isnan(sample))])
samples = qml.math.cast_like(samples, qml.math.int64(0))
samples = qml.math.cast_like(samples, qml.math.int8(0))
samples = np.apply_along_axis(_sample_to_str, -1, samples)
batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires))
if mp.all_outcomes:
Expand Down
26 changes: 15 additions & 11 deletions pennylane/measurements/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def circuit(x):

if self.obs is None and not isinstance(self.mv, MeasurementValue):
# convert samples and outcomes (if using) from arrays to str for dict keys
batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires))

# remove nans
mask = qml.math.isnan(samples)
Expand All @@ -307,21 +308,24 @@ def circuit(x):
mask = np.logical_not(np.any(mask, axis=tuple(range(1, samples.ndim))))
samples = samples[mask, ...]

# convert to string
def convert(x):
return f"{x:0{num_wires}b}"
def convert(sample):
# convert array of ints to string
return "".join(str(s) for s in sample)

exp2 = 2 ** np.arange(num_wires - 1, -1, -1)
samples = np.einsum("...i,i", samples, exp2)
new_shape = samples.shape
samples = qml.math.cast_like(samples, qml.math.int64(0))
samples = list(map(convert, samples.ravel()))
samples = np.array(samples).reshape(new_shape)
new_shape = samples.shape[:-1]
# Flatten broadcasting axis
flattened_samples = np.reshape(samples, (-1, shape[-1])).astype(np.int8)
samples = list(map(convert, flattened_samples))
samples = np.reshape(np.array(samples), new_shape)

batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires))
if self.all_outcomes:

def convert_from_int(x):
# convert int to binary string
return f"{x:0{num_wires}b}"

num_wires = len(self.wires) if len(self.wires) > 0 else shape[-1]
outcomes = list(map(convert, range(2**num_wires)))
outcomes = list(map(convert_from_int, range(2**num_wires)))

elif self.all_outcomes:
# This also covers statistics for mid-circuit measurements manipulated using
Expand Down
2 changes: 1 addition & 1 deletion pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,5 +381,5 @@ def gather_mcm(measurement, samples):
meas_tmp = measurement.__class__(wires=wires)
new_measurement = meas_tmp.process_samples(mcm_samples, wire_order=wires)
if isinstance(measurement, CountsMP) and not use_as_is:
new_measurement = dict(sorted((int(x, 2), y) for x, y in new_measurement.items()))
new_measurement = dict(sorted((int(x), y) for x, y in new_measurement.items()))
return new_measurement
24 changes: 17 additions & 7 deletions tests/measurements/test_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,19 +145,29 @@ def test_counts_with_nan_samples(self):
total_counts = sum(count for count in result.values())
assert total_counts == 997

@pytest.mark.parametrize("n_wires", [5, 8, 10])
@pytest.mark.parametrize("all_outcomes", [True, False])
def test_counts_multi_wires_no_overflow(self, n_wires, all_outcomes):
@pytest.mark.parametrize("batch_size", [None, 1, 4])
@pytest.mark.parametrize("n_wires", [4, 10, 65])
@pytest.mark.parametrize("all_outcomes", [False, True])
def test_counts_multi_wires_no_overflow(self, n_wires, all_outcomes, batch_size):
"""Test that binary strings for wire samples are not negative due to overflow."""
if all_outcomes and n_wires == 65:
pytest.skip("Too much memory being used, skipping")
shots = 1000
total_wires = 10
samples = np.random.choice([0, 1], size=(shots, total_wires)).astype(np.float64)
total_wires = 65
shape = (batch_size, shots, total_wires) if batch_size else (shots, total_wires)
samples = np.random.choice([0, 1], size=shape).astype(np.float64)
result = qml.counts(wires=list(range(n_wires)), all_outcomes=all_outcomes).process_samples(
samples, wire_order=list(range(total_wires))
)

assert sum(result.values()) == shots
assert all(0 <= int(sample, 2) <= 2**n_wires for sample in result.keys())
if batch_size:
assert len(result) == batch_size
for r in result:
assert sum(r.values()) == shots
assert all("-" not in sample for sample in r.keys())
else:
assert sum(result.values()) == shots
assert all("-" not in sample for sample in result.keys())

def test_counts_obs(self):
"""Test that the counts function outputs counts of the right size for observables"""
Expand Down

0 comments on commit 6d1fd42

Please sign in to comment.