Skip to content

Commit

Permalink
Support broadcasting in measure and measure_with_samples (#4238)
Browse files Browse the repository at this point in the history
* Support broadcasting in state measurements

* docs for is_state_batched

* Support broadcasting in sample measurements

* Apply suggestions from code review

Co-authored-by: Christina Lee <christina@xanadu.ai>

* black

* black

* Remove total_copies

* fix

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
eddddddy and albi3ro authored Jun 16, 2023
1 parent 5824a8e commit dc3078a
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 29 deletions.
65 changes: 48 additions & 17 deletions pennylane/devices/qubit/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,67 +28,92 @@


def state_diagonalizing_gates(
measurementprocess: StateMeasurement, state: TensorLike
measurementprocess: StateMeasurement, state: TensorLike, is_state_batched: bool = False
) -> TensorLike:
"""Apply a measurement to state when the measurement process has an observable with diagonalizing gates.
Args:
measurementprocess (StateMeasurement): measurement to apply to the state
state (TensorLike): state to apply the measurement to
is_state_batched (bool): whether the state is batched or not
Returns:
TensorLike: the result of the measurement
"""
for op in measurementprocess.diagonalizing_gates():
state = apply_operation(op, state)
state = apply_operation(op, state, is_state_batched=is_state_batched)

total_indices = len(state.shape)
total_indices = len(state.shape) - is_state_batched
wires = Wires(range(total_indices))
return measurementprocess.process_state(math.flatten(state), wires)

flattened_state = (
math.reshape(state, (state.shape[0], -1)) if is_state_batched else math.flatten(state)
)
return measurementprocess.process_state(flattened_state, wires)


def csr_dot_products(measurementprocess: ExpectationMP, state: TensorLike) -> TensorLike:
def csr_dot_products(
measurementprocess: ExpectationMP, state: TensorLike, is_state_batched: bool = False
) -> TensorLike:
"""Measure the expectation value of an observable using dot products between ``scipy.csr_matrix``
representations.
Args:
measurementprocess (ExpectationMP): measurement process to apply to the state
state (TensorLike): the state to measure
is_state_batched (bool): whether the state is batched or not
Returns:
TensorLike: the result of the measurement
"""
total_wires = len(state.shape)
total_wires = len(state.shape) - is_state_batched
Hmat = measurementprocess.obs.sparse_matrix(wire_order=list(range(total_wires)))
state = math.toarray(state).flatten()

# Find the expectation value using the <\psi|H|\psi> matrix contraction
bra = csr_matrix(math.conj(state))
ket = csr_matrix(state[..., None])
new_ket = csr_matrix.dot(Hmat, ket)
res = csr_matrix.dot(bra, new_ket).toarray()[0]
if is_state_batched:
state = math.toarray(state).reshape(math.shape(state)[0], -1)

bra = csr_matrix(math.conj(state))
ket = csr_matrix(state)
new_bra = bra.dot(Hmat)
res = new_bra.multiply(ket).sum(axis=1).getA()

else:
state = math.toarray(state).flatten()

# Find the expectation value using the <\psi|H|\psi> matrix contraction
bra = csr_matrix(math.conj(state))
ket = csr_matrix(state[..., None])
new_ket = csr_matrix.dot(Hmat, ket)
res = csr_matrix.dot(bra, new_ket).toarray()[0]

return math.real(math.squeeze(res))


def sum_of_terms_method(measurementprocess: ExpectationMP, state: TensorLike) -> TensorLike:
def sum_of_terms_method(
measurementprocess: ExpectationMP, state: TensorLike, is_state_batched: bool = False
) -> TensorLike:
"""Measure the expecation value of the state when the measured observable is a ``Hamiltonian`` or ``Sum``
and it must be backpropagation compatible.
Args:
measurementprocess (ExpectationMP): measurement process to apply to the state
state (TensorLike): the state to measure
is_state_batched (bool): whether the state is batched or not
Returns:
TensorLike: the result of the measurement
"""
if isinstance(measurementprocess.obs, Sum):
# Recursively call measure on each term, so that the best measurement method can
# be used for each term
return sum(measure(ExpectationMP(term), state) for term in measurementprocess.obs)
return sum(
measure(ExpectationMP(term), state, is_state_batched=is_state_batched)
for term in measurementprocess.obs
)
# else hamiltonian
return sum(
c * measure(ExpectationMP(t), state) for c, t in zip(*measurementprocess.obs.terms())
c * measure(ExpectationMP(t), state, is_state_batched=is_state_batched)
for c, t in zip(*measurementprocess.obs.terms())
)


Expand All @@ -100,6 +125,7 @@ def get_measurement_function(
Args:
measurementprocess (MeasurementProcess): measurement process to apply to the state
state (TensorLike): the state to measure
is_state_batched (bool): whether the state is batched or not
Returns:
Callable: function that returns the measurement result
Expand Down Expand Up @@ -127,14 +153,19 @@ def get_measurement_function(
raise NotImplementedError


def measure(measurementprocess: MeasurementProcess, state: TensorLike) -> TensorLike:
def measure(
measurementprocess: MeasurementProcess, state: TensorLike, is_state_batched: bool = False
) -> TensorLike:
"""Apply a measurement process to a state.
Args:
measurementprocess (MeasurementProcess): measurement process to apply to the state
state (TensorLike): the state to measure
is_state_batched (bool): whether the state is batched or not
Returns:
Tensorlike: the result of the measurement
"""
return get_measurement_function(measurementprocess, state)(measurementprocess, state)
return get_measurement_function(measurementprocess, state)(
measurementprocess, state, is_state_batched
)
57 changes: 45 additions & 12 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


def measure_with_samples(
mp: SampleMeasurement, state: np.ndarray, shots: Shots, rng=None
mp: SampleMeasurement, state: np.ndarray, shots: Shots, is_state_batched: bool = False, rng=None
) -> TensorLike:
"""
Returns the samples of the measurement process performed on the given state.
Expand All @@ -33,6 +33,7 @@ def measure_with_samples(
mp (~.measurements.SampleMeasurement): The sample measurement to perform
state (np.ndarray[complex]): The state vector to sample from
shots (~.measurements.Shots): The number of samples to take
is_state_batched (bool): whether the state is batched or not
rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
If no value is provided, a default RNG will be used.
Expand All @@ -47,7 +48,10 @@ def measure_with_samples(

def _sum_for_single_shot(s):
return sum(
c * measure_with_samples(ExpectationMP(t), state, s, rng=rng)
c
* measure_with_samples(
ExpectationMP(t), state, s, is_state_batched=is_state_batched, rng=rng
)
for c, t in zip(*mp.obs.terms())
)

Expand All @@ -58,18 +62,23 @@ def _sum_for_single_shot(s):

def _sum_for_single_shot(s):
return sum(
measure_with_samples(ExpectationMP(t), state, s, rng=rng) for t in mp.obs
measure_with_samples(
ExpectationMP(t), state, s, is_state_batched=is_state_batched, rng=rng
)
for t in mp.obs
)

unsqueezed_results = tuple(_sum_for_single_shot(Shots(s)) for s in shots)
return unsqueezed_results if shots.has_partitioned_shots else unsqueezed_results[0]

# measure with the usual method (rotate into the measurement basis)
return _measure_with_samples_diagonalizing_gates(mp, state, shots, rng=rng)
return _measure_with_samples_diagonalizing_gates(
mp, state, shots, is_state_batched=is_state_batched, rng=rng
)


def _measure_with_samples_diagonalizing_gates(
mp: SampleMeasurement, state: np.ndarray, shots: Shots, rng=None
mp: SampleMeasurement, state: np.ndarray, shots: Shots, is_state_batched: bool = False, rng=None
) -> TensorLike:
"""
Returns the samples of the measurement process performed on the given state,
Expand All @@ -80,6 +89,7 @@ def _measure_with_samples_diagonalizing_gates(
mp (~.measurements.SampleMeasurement): The sample measurement to perform
state (np.ndarray[complex]): The state vector to sample from
shots (~.measurements.Shots): The number of samples to take
is_state_batched (bool): whether the state is batched or not
rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
If no value is provided, a default RNG will be used.
Expand All @@ -90,9 +100,12 @@ def _measure_with_samples_diagonalizing_gates(
# apply diagonalizing gates
pre_rotated_state = state
for op in mp.diagonalizing_gates():
pre_rotated_state = apply_operation(op, pre_rotated_state)
pre_rotated_state = apply_operation(
op, pre_rotated_state, is_state_batched=is_state_batched
)

wires = qml.wires.Wires(range(len(state.shape)))
total_indices = len(state.shape) - is_state_batched
wires = qml.wires.Wires(range(total_indices))

# if there is a shot vector, build a list containing results for each shot entry
if shots.has_partitioned_shots:
Expand All @@ -101,7 +114,9 @@ def _measure_with_samples_diagonalizing_gates(
# currently we call sample_state for each shot entry, but it may be
# better to call sample_state just once with total_shots, then use
# the shot_range keyword argument
samples = sample_state(pre_rotated_state, shots=s, wires=wires, rng=rng)
samples = sample_state(
pre_rotated_state, shots=s, is_state_batched=is_state_batched, wires=wires, rng=rng
)

if not isinstance(processed := mp.process_samples(samples, wires), dict):
processed = qml.math.squeeze(processed)
Expand All @@ -110,21 +125,30 @@ def _measure_with_samples_diagonalizing_gates(

return tuple(processed_samples)

samples = sample_state(pre_rotated_state, shots=shots.total_shots, wires=wires, rng=rng)
samples = sample_state(
pre_rotated_state,
shots=shots.total_shots,
is_state_batched=is_state_batched,
wires=wires,
rng=rng,
)

if not isinstance(processed := mp.process_samples(samples, wires), dict):
processed = qml.math.squeeze(processed)

return processed


def sample_state(state, shots: int, wires=None, rng=None) -> np.ndarray:
def sample_state(
state, shots: int, is_state_batched: bool = False, wires=None, rng=None
) -> np.ndarray:
"""
Returns a series of samples of a state.
Args:
state (array[complex]): A state vector to be sampled
shots (int): The number of samples to take
is_state_batched (bool): whether the state is batched or not
wires (Sequence[int]): The wires to sample
rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]):
A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
Expand All @@ -134,12 +158,21 @@ def sample_state(state, shots: int, wires=None, rng=None) -> np.ndarray:
ndarray[bool]: Sample values of the shape (shots, num_wires)
"""
rng = np.random.default_rng(rng)
state_wires = qml.wires.Wires(range(len(state.shape)))

total_indices = len(state.shape) - is_state_batched
state_wires = qml.wires.Wires(range(total_indices))

wires_to_sample = wires or state_wires
num_wires = len(wires_to_sample)
basis_states = np.arange(2**num_wires)

probs = qml.probs(wires=wires_to_sample).process_state(state, state_wires)
samples = rng.choice(basis_states, shots, p=probs)

if is_state_batched:
# rng.choice doesn't support broadcasting
samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs])
else:
samples = rng.choice(basis_states, shots, p=probs)

powers_of_two = 1 << np.arange(num_wires, dtype=np.int64)[::-1]
return (samples[..., None] & powers_of_two).astype(np.bool8)
71 changes: 71 additions & 0 deletions tests/devices/qubit/test_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,77 @@ def test_sum_expval_eigs(self, obs, expected):
assert np.allclose(res, expected)


class TestBroadcasting:
"""Test that measurements work when the state has a batch dim"""

@pytest.mark.parametrize(
"measurement, expected",
[
(
qml.state(),
np.array(
[
[0, 0, 0, 1],
[1 / np.sqrt(2), 0, 1 / np.sqrt(2), 0],
[1 / 2, 1 / 2, 1 / 2, 1 / 2],
]
),
),
(
qml.density_matrix(wires=[0, 1]),
np.array(
[
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1]],
[[1 / 2, 0, 1 / 2, 0], [0, 0, 0, 0], [1 / 2, 0, 1 / 2, 0], [0, 0, 0, 0]],
[
[1 / 4, 1 / 4, 1 / 4, 1 / 4],
[1 / 4, 1 / 4, 1 / 4, 1 / 4],
[1 / 4, 1 / 4, 1 / 4, 1 / 4],
[1 / 4, 1 / 4, 1 / 4, 1 / 4],
],
]
),
),
(
qml.probs(wires=[0, 1]),
np.array([[0, 0, 0, 1], [1 / 2, 0, 1 / 2, 0], [1 / 4, 1 / 4, 1 / 4, 1 / 4]]),
),
(qml.expval(qml.PauliZ(1)), np.array([-1, 1, 0])),
(qml.var(qml.PauliZ(1)), np.array([0, 0, 1])),
],
)
def test_state_measurement(self, measurement, expected):
"""Test that broadcasting works for regular state measurements"""
state = [
np.array([[0, 0], [0, 1]]),
np.array([[1, 0], [1, 0]]) / np.sqrt(2),
np.array([[1, 1], [1, 1]]) / 2,
]
state = np.stack(state)

res = measure(measurement, state, is_state_batched=True)
assert np.allclose(res, expected)

def test_sparse_hamiltonian(self):
"""Test that broadcasting works for expectation values of SparseHamiltonians"""
H = qml.Hamiltonian([2], [qml.PauliZ(1)])
measurement = qml.expval(H)

state = [
np.array([[0, 0], [0, 1]]),
np.array([[1, 0], [1, 0]]) / np.sqrt(2),
np.array([[1, 1], [1, 1]]) / 2,
]
state = np.stack(state)

measurement_fn = get_measurement_function(measurement, state)
assert measurement_fn is csr_dot_products

res = measure(measurement, state, is_state_batched=True)
expected = np.array([-2, 2, 0])
assert np.allclose(res, expected)


class TestSumOfTermsDifferentiability:
@staticmethod
def f(scale, coeffs, n_wires=10, offset=0.1, convert_to_hamiltonian=False):
Expand Down
Loading

0 comments on commit dc3078a

Please sign in to comment.