diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3a98cce187d..1f4ba35cd86 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -279,6 +279,9 @@ * Removed the warning that an observable might not be hermitian in `qnode` executions. This enables jit-compilation. [(#5506)](https://github.com/PennyLaneAI/pennylane/pull/5506) +* Implement `Shots.bins()` method. + [(#5476)](https://github.com/PennyLaneAI/pennylane/pull/5476) +

Breaking changes 💔

* Operator dunder methods now combine like-operator arithmetic classes via `lazy=False`. This reduces the chance of `RecursionError` and makes nested diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py index 1daeb2a6fae..1044de957c5 100644 --- a/pennylane/devices/qubit/sampling.py +++ b/pennylane/devices/qubit/sampling.py @@ -268,31 +268,6 @@ def _process_single_shot(samples): return tuple(processed) - # if there is a shot vector, build a list containing results for each shot entry - if shots.has_partitioned_shots: - processed_samples = [] - for s in shots: - # currently we call sample_state for each shot entry, but it may be - # better to call sample_state just once with total_shots, then use - # the shot_range keyword argument - try: - samples = sample_state( - state, - shots=s, - is_state_batched=is_state_batched, - wires=wires, - rng=rng, - prng_key=prng_key, - ) - except ValueError as e: - if str(e) != "probabilities contain NaN": - raise e - samples = qml.math.full((s, len(wires)), 0) - - processed_samples.append(_process_single_shot(samples)) - - return tuple(zip(*processed_samples)) - try: samples = sample_state( state, @@ -307,7 +282,15 @@ def _process_single_shot(samples): raise e samples = qml.math.full((shots.total_shots, len(wires)), 0) - return _process_single_shot(samples) + processed_samples = [] + for lower, upper in shots.bins(): + shot = _process_single_shot(samples[..., lower:upper, :]) + processed_samples.append(shot) + + if shots.has_partitioned_shots: + return tuple(zip(*processed_samples)) + + return processed_samples[0] def _measure_classical_shadow( diff --git a/pennylane/measurements/shots.py b/pennylane/measurements/shots.py index 448f796e732..9362f6fe792 100644 --- a/pennylane/measurements/shots.py +++ b/pennylane/measurements/shots.py @@ -261,3 +261,20 @@ def has_partitioned_shots(self): def num_copies(self): """The total number of copies of any shot quantity.""" return sum(s.copies for s in self.shot_vector) + + def bins(self): + """ + Yields: + tuple: A tuple containing the lower and upper bounds for each shot quantity in shot_vector. + + Example: + >>> shots = Shots((1, 1, 2, 3)) + >>> list(shots.bins()) + [(0,1), (1,2), (2,4), (4,7)] + """ + lower_bound = 0 + for sc in self.shot_vector: + for _ in range(sc.copies): + upper_bound = lower_bound + sc.shots + yield lower_bound, upper_bound + lower_bound = upper_bound diff --git a/tests/devices/qubit/test_sampling.py b/tests/devices/qubit/test_sampling.py index cb1455ab3da..8238cd65605 100644 --- a/tests/devices/qubit/test_sampling.py +++ b/tests/devices/qubit/test_sampling.py @@ -856,7 +856,7 @@ def test_nonsample_measure_shot_vector(self, shots, measurement, expected): r = r[0] assert r.shape == expected.shape - assert np.allclose(r, expected, atol=0.01) + assert np.allclose(r, expected, atol=0.02) @pytest.mark.jax @@ -1071,7 +1071,7 @@ def test_nonsample_measure_shot_vector(self, mocker, shots, measurement, expecte r = r[0] assert r.shape == expected.shape - assert np.allclose(r, expected, atol=0.01) + assert np.allclose(r, expected, atol=0.03) class TestHamiltonianSamples: diff --git a/tests/measurements/test_shots.py b/tests/measurements/test_shots.py index 7aebda3db6d..cad4e8b390a 100644 --- a/tests/measurements/test_shots.py +++ b/tests/measurements/test_shots.py @@ -283,3 +283,23 @@ def test_shots_rmul(self): scaled_sh1 = 2 * sh1 rev_scaled_sh1 = sh1 * 2 assert scaled_sh1.total_shots == rev_scaled_sh1.total_shots + + +class TestShotsBins: + """Tests Shots.bins() method.""" + + def test_when_shots_is_none(self): + """Tests that the method returns no bins when shots is None.""" + shots = Shots(None) + assert not list(shots.bins()) + + def test_when_shots_is_int(self): + """Tests that the method returns the correct bins when shots is an int.""" + shots = Shots(10) + assert list(shots.bins()) == [(0, 10)] + + @pytest.mark.parametrize("sequence", [[1, 1, 3, 4], [(1, 2), 3, 4]]) + def test_when_shots_is_sequence_with_copies(self, sequence): + """Tests that the method returns the correct bins when shots is a sequence with copies.""" + shots = Shots(sequence) + assert list(shots.bins()) == [(0, 1), (1, 2), (2, 5), (5, 9)]