From bc45606263fcf58959440a535a4373555e369dc4 Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:31:07 +0000 Subject: [PATCH] Bitarray postselect (#12693) (#12836) * define BitArray.postselect() * add test for BitArray.postselect() * lint * remove redundant docstring text * Update qiskit/primitives/containers/bit_array.py Co-authored-by: Ian Hincks * docstring ticks (BitArray.postselect()) Co-authored-by: Ian Hincks * Simpler tests for BitArray.postselect * lint * add release note * check postselect() arg lengths match * fix postselect tests - fix bugs with checking that ValueError is raised. - addtionally run all tests on a "flat" data input * lint * Fix type-hint We immediately check the lengths of these args, so they should be Sequences, not Iterables. * remove spurious print() * lint * lint * use bitwise operations for faster postselect - Also added support for negative indices - Also updated tests * remove spurious print() * end final line of release note * try to fix docstring formatting * fix bitarray test assertion Co-authored-by: Takashi Imamichi <31178928+t-imamichi@users.noreply.github.com> * disallow postselect positional kwarg Co-authored-by: Takashi Imamichi <31178928+t-imamichi@users.noreply.github.com> * fix numpy dtype args * Simpler kwarg: "assume_unique" * lint (line too long) * simplification: remove assume_unique kwarg * improve misleading comment * raise IndexError if indices out of range - Change ValueError to IndexError. - Add check for out-of-range negative indices. - Simplify use of mod - Update test conditions (include checks for off-by-one errors) * lint * add negative-contradiction test * Update docstring with IndexErrors * lint * change slice_bits error from ValueError to IndexError * update slice_bits test to use IndexError * change ValueError to IndexError in slice_shots also update tests for this error * update error type in slice_shots docstring * Revert ValueError to IndexError changes Reverting these changes as they will instead be made in a separate PR. This reverts commit 8f3217838c6632d30ef300445fcca1590454b536. Revert "update error type in slice_shots docstring" This reverts commit 50545efbf26f6fac72c7c00919ae0995b8464ba6. Revert "change ValueError to IndexError in slice_shots" This reverts commit c4becd9b0e4363797331157b176c1603dabc40c9. Revert "update slice_bits test to use IndexError" This reverts commit c2b00390da40b12d5821476de850803da2b72a69. * fix docstring formatting Co-authored-by: Takashi Imamichi <31178928+t-imamichi@users.noreply.github.com> * allow selection to be int instead of bool * In tests, give selection as type int * lint * add example to release note * fix typo in test case * add check of test Co-authored-by: Takashi Imamichi <31178928+t-imamichi@users.noreply.github.com> * lint --------- Co-authored-by: Ian Hincks Co-authored-by: Takashi Imamichi <31178928+t-imamichi@users.noreply.github.com> (cherry picked from commit 0c03808cd177c857a3458df7e58ca9e800185577) Co-authored-by: aeddins-ibm <60495383+aeddins-ibm@users.noreply.github.com> --- qiskit/primitives/containers/bit_array.py | 91 +++++++++++++++++++ .../bitarray-postselect-659b8f7801ccaa60.yaml | 11 +++ .../primitives/containers/test_bit_array.py | 79 ++++++++++++++++ 3 files changed, 181 insertions(+) create mode 100644 releasenotes/notes/bitarray-postselect-659b8f7801ccaa60.yaml diff --git a/qiskit/primitives/containers/bit_array.py b/qiskit/primitives/containers/bit_array.py index 11cd91a96521..29ff3240f3bf 100644 --- a/qiskit/primitives/containers/bit_array.py +++ b/qiskit/primitives/containers/bit_array.py @@ -470,6 +470,97 @@ def slice_shots(self, indices: int | Sequence[int]) -> "BitArray": arr = arr[..., indices, :] return BitArray(arr, self.num_bits) + def postselect( + self, + indices: Sequence[int] | int, + selection: Sequence[bool | int] | bool | int, + ) -> BitArray: + """Post-select this bit array based on sliced equality with a given bitstring. + + .. note:: + If this bit array contains any shape axes, it is first flattened into a long list of shots + before applying post-selection. This is done because :class:`~BitArray` cannot handle + ragged numbers of shots across axes. + + Args: + indices: A list of the indices of the cbits on which to postselect. + If this bit array was produced by a sampler, then an index ``i`` corresponds to the + :class:`~.ClassicalRegister` location ``creg[i]`` (as in :meth:`~slice_bits`). + Negative indices are allowed. + + selection: A list of binary values (will be cast to ``bool``) of length matching + ``indices``, with ``indices[i]`` corresponding to ``selection[i]``. Shots will be + discarded unless all cbits specified by ``indices`` have the values given by + ``selection``. + + Returns: + A new bit array with ``shape=(), num_bits=data.num_bits, num_shots<=data.num_shots``. + + Raises: + IndexError: If ``max(indices)`` is greater than or equal to :attr:`num_bits`. + IndexError: If ``min(indices)`` is less than negative :attr:`num_bits`. + ValueError: If the lengths of ``selection`` and ``indices`` do not match. + """ + if isinstance(indices, int): + indices = (indices,) + if isinstance(selection, (bool, int)): + selection = (selection,) + selection = np.asarray(selection, dtype=bool) + + num_indices = len(indices) + + if len(selection) != num_indices: + raise ValueError("Lengths of indices and selection do not match.") + + num_bytes = self._array.shape[-1] + indices = np.asarray(indices) + + if num_indices > 0: + if indices.max() >= self.num_bits: + raise IndexError( + f"index {int(indices.max())} out of bounds for the number of bits {self.num_bits}." + ) + if indices.min() < -self.num_bits: + raise IndexError( + f"index {int(indices.min())} out of bounds for the number of bits {self.num_bits}." + ) + + flattened = self.reshape((), self.size * self.num_shots) + + # If no conditions, keep all data, but flatten as promised: + if num_indices == 0: + return flattened + + # Make negative bit indices positive: + indices %= self.num_bits + + # Handle special-case of contradictory conditions: + if np.intersect1d(indices[selection], indices[np.logical_not(selection)]).size > 0: + return BitArray(np.empty((0, num_bytes), dtype=np.uint8), num_bits=self.num_bits) + + # Recall that creg[0] is the LSb: + byte_significance, bit_significance = np.divmod(indices, 8) + # least-significant byte is at last position: + byte_idx = (num_bytes - 1) - byte_significance + # least-significant bit is at position 0: + bit_offset = bit_significance.astype(np.uint8) + + # Get bitpacked representation of `indices` (bitmask): + bitmask = np.zeros(num_bytes, dtype=np.uint8) + np.bitwise_or.at(bitmask, byte_idx, np.uint8(1) << bit_offset) + + # Get bitpacked representation of `selection` (desired bitstring): + selection_bytes = np.zeros(num_bytes, dtype=np.uint8) + ## This assumes no contradictions present, since those were already checked for: + np.bitwise_or.at( + selection_bytes, byte_idx, np.asarray(selection, dtype=np.uint8) << bit_offset + ) + + return BitArray( + flattened._array[((flattened._array & bitmask) == selection_bytes).all(axis=-1)], + num_bits=self.num_bits, + ) + def expectation_values(self, observables: ObservablesArrayLike) -> NDArray[np.float64]: """Compute the expectation values of the provided observables, broadcasted against this bit array. diff --git a/releasenotes/notes/bitarray-postselect-659b8f7801ccaa60.yaml b/releasenotes/notes/bitarray-postselect-659b8f7801ccaa60.yaml new file mode 100644 index 000000000000..33ce17bafa8d --- /dev/null +++ b/releasenotes/notes/bitarray-postselect-659b8f7801ccaa60.yaml @@ -0,0 +1,11 @@ +--- +features_primitives: + - | + Added a new method :meth:`.BitArray.postselect` that returns all shots containing specified bit values. + Example usage:: + + from qiskit.primitives.containers import BitArray + + ba = BitArray.from_counts({'110': 2, '100': 4, '000': 3}) + print(ba.postselect([0,2], [0,1]).get_counts()) + # {'110': 2, '100': 4} diff --git a/test/python/primitives/containers/test_bit_array.py b/test/python/primitives/containers/test_bit_array.py index 4aeeba854b33..bd41d127689d 100644 --- a/test/python/primitives/containers/test_bit_array.py +++ b/test/python/primitives/containers/test_bit_array.py @@ -719,3 +719,82 @@ def test_expectation_values(self): _ = ba.expectation_values("Z") with self.assertRaisesRegex(ValueError, "is not diagonal"): _ = ba.expectation_values("X" * ba.num_bits) + + def test_postselection(self): + """Test the postselection method.""" + + flat_data = np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + ], + dtype=bool, + ) + + shaped_data = np.array( + [ + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + ], + [ + [1, 0, 1, 0, 1, 0, 1, 0, 1, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ], + ] + ], + dtype=bool, + ) + + for dataname, bool_array in zip(["flat", "shaped"], [flat_data, shaped_data]): + + bit_array = BitArray.from_bool_array(bool_array, order="little") + # indices value of i <-> creg[i] <-> bool_array[..., i] + + num_bits = bool_array.shape[-1] + bool_array = bool_array.reshape(-1, num_bits) + + test_cases = [ + ("basic", [0, 1], [0, 0]), + ("multibyte", [0, 9], [0, 1]), + ("repeated", [5, 5, 5], [0, 0, 0]), + ("contradict", [5, 5, 5], [1, 0, 0]), + ("unsorted", [5, 0, 9, 3], [1, 0, 1, 0]), + ("negative", [-5, 1, -2, -10], [1, 0, 1, 0]), + ("negcontradict", [4, -6], [1, 0]), + ("trivial", [], []), + ("bareindex", 6, 0), + ] + + for name, indices, selection in test_cases: + with self.subTest("_".join([dataname, name])): + postselected_bools = np.unpackbits( + bit_array.postselect(indices, selection).array[:, ::-1], + count=num_bits, + axis=-1, + bitorder="little", + ).astype(bool) + if isinstance(indices, int): + indices = (indices,) + if isinstance(selection, bool): + selection = (selection,) + answer = bool_array[np.all(bool_array[:, indices] == selection, axis=-1)] + if name in ["contradict", "negcontradict"]: + self.assertEqual(len(answer), 0) + else: + self.assertGreater(len(answer), 0) + np.testing.assert_equal(postselected_bools, answer) + + error_cases = [ + ("aboverange", [0, 6, 10], [True, True, False], IndexError), + ("belowrange", [0, 6, -11], [True, True, False], IndexError), + ("mismatch", [0, 1, 2], [False, False], ValueError), + ] + for name, indices, selection, error in error_cases: + with self.subTest(dataname + "_" + name): + with self.assertRaises(error): + bit_array.postselect(indices, selection)