Skip to content

Commit

Permalink
Use SampleMPs in dynamic_one_shot (#5486)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
The native MCM workflow breaks the device API where a sequence of
MeasurementProcess objects is expected in the output.

**Description of the Change:**
Introduce SampleMPs in the auxiliary tape. Pass the mid_measurements
dictionary around simulate and sampling to return the correct sample
measurements. Modify the dynamic_one_shot transform post-processing
function accordingly.

**Benefits:**
Conform to current API.
Road to jax.jit support.

**Possible Drawbacks:**
Ad hoc post-processing required in measure_with_samples.

**Related GitHub Issues:**
[sc-60945]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
vincentmr and albi3ro committed Apr 22, 2024
1 parent be8a22e commit 44018e9
Show file tree
Hide file tree
Showing 12 changed files with 399 additions and 111 deletions.
21 changes: 12 additions & 9 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
>>> circuit()
tensor([1.+6.123234e-17j, 0.-6.123234e-17j], requires_grad=True)
```

* The `qml.AmplitudeAmplification` operator is introduced, which is a high-level interface for amplitude amplification and its variants.
[(#5160)](https://github.com/PennyLaneAI/pennylane/pull/5160)

Expand All @@ -185,7 +185,7 @@
return qml.probs(wires=range(3))

```

```pycon
>>> print(np.round(circuit(), 3))
[0.013, 0.013, 0.91, 0.013, 0.013, 0.013, 0.013, 0.013]
Expand All @@ -212,7 +212,7 @@
but for usage with new operator arithmetic.
[(#5216)](https://github.com/PennyLaneAI/pennylane/pull/5216)

* The `qml.TrotterProduct` operator now supports error estimation functionality.
* The `qml.TrotterProduct` operator now supports error estimation functionality.
[(#5384)](https://github.com/PennyLaneAI/pennylane/pull/5384)

```pycon
Expand Down Expand Up @@ -245,18 +245,18 @@
* The `molecular_hamiltonian` function calls `PySCF` directly when `method='pyscf'` is selected.
[(#5118)](https://github.com/PennyLaneAI/pennylane/pull/5118)

* The generators in the source code return operators consistent with the global setting for
`qml.operator.active_new_opmath()` wherever possible. `Sum`, `SProd` and `Prod` instances
will be returned even after disabling the new operator arithmetic in cases where they offer
* The generators in the source code return operators consistent with the global setting for
`qml.operator.active_new_opmath()` wherever possible. `Sum`, `SProd` and `Prod` instances
will be returned even after disabling the new operator arithmetic in cases where they offer
additional functionality not available using legacy operators.
[(#5253)](https://github.com/PennyLaneAI/pennylane/pull/5253)
[(#5410)](https://github.com/PennyLaneAI/pennylane/pull/5410)
[(#5411)](https://github.com/PennyLaneAI/pennylane/pull/5411)
[(#5411)](https://github.com/PennyLaneAI/pennylane/pull/5411)
[(#5421)](https://github.com/PennyLaneAI/pennylane/pull/5421)

* Upgraded `null.qubit` to the new device API. Also, added support for all measurements and various modes of differentiation.
[(#5211)](https://github.com/PennyLaneAI/pennylane/pull/5211)

* `ApproxTimeEvolution` is now compatible with any operator that defines a `pauli_rep`.
[(#5362)](https://github.com/PennyLaneAI/pennylane/pull/5362)

Expand Down Expand Up @@ -338,6 +338,9 @@

<h3>Breaking changes 💔</h3>

* Use `SampleMP`s in the `dynamic_one_shot` transform to get back the values of the mid-circuit measurements.
[(#5486)](https://github.com/PennyLaneAI/pennylane/pull/5486)

* Operator dunder methods now combine like-operator arithmetic classes via `lazy=False`. This reduces the chance of `RecursionError` and makes nested
operators easier to work with.
[(#5478)](https://github.com/PennyLaneAI/pennylane/pull/5478)
Expand All @@ -359,7 +362,7 @@

* `qml.pauli.pauli_mult` and `qml.pauli.pauli_mult_with_phase` are now removed. Instead, you should use `qml.simplify(qml.prod(pauli_1, pauli_2))` to get the reduced operator.
[(#5324)](https://github.com/PennyLaneAI/pennylane/pull/5324)

```pycon
>>> op = qml.simplify(qml.prod(qml.PauliX(0), qml.PauliZ(0)))
>>> op
Expand Down
25 changes: 15 additions & 10 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,6 @@ def execute(self, circuit, **kwargs):
)
if has_mcm:
mid_measurements = kwargs["mid_measurements"]
mid_values = np.array(tuple(mid_measurements.values()))
if np.any(mid_values == -1):
for k, v in tuple(mid_measurements.items()):
if v == -1:
mid_measurements.pop(k)
return None, mid_measurements

# generate computational basis samples
sample_type = (SampleMP, CountsMP, ClassicalShadowMP, ShadowExpvalMP)
Expand All @@ -308,13 +302,24 @@ def execute(self, circuit, **kwargs):
self.apply([qml.adjoint(g, lazy=False) for g in reversed(diagonalizing_gates)])

# compute the required statistics
if has_mcm:
n_mcms = len(mid_measurements)
stat_circuit = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements[0:-n_mcms],
shots=1,
trainable_params=circuit.trainable_params,
)
else:
stat_circuit = circuit
if self._shot_vector is not None:
results = self.shot_vec_statistics(circuit)
results = self.shot_vec_statistics(stat_circuit)

else:
results = self.statistics(circuit)
results = self.statistics(stat_circuit)
if has_mcm:
results.extend(list(mid_measurements.values()))
single_measurement = len(circuit.measurements) == 1

results = results[0] if single_measurement else tuple(results)
# increment counter for number of executions of qubit device
self._num_executions += 1
Expand All @@ -336,7 +341,7 @@ def execute(self, circuit, **kwargs):
)
self.tracker.record()

return (results, mid_measurements) if has_mcm else results
return results

def shot_vec_statistics(self, circuit: QuantumTape):
"""Process measurement results from circuit execution using a device
Expand Down
21 changes: 8 additions & 13 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

from functools import singledispatch
from string import ascii_letters as alphabet

import numpy as np

import pennylane as qml

from pennylane import math
from pennylane.measurements import MidMeasureMP
from pennylane.measurements import MidMeasureMP, Shots
from pennylane.ops import Conditional

SQRT2INV = 1 / math.sqrt(2)
Expand Down Expand Up @@ -261,21 +261,16 @@ def apply_mid_measure(
if is_state_batched:
raise ValueError("MidMeasureMP cannot be applied to batched states.")
if not np.allclose(np.linalg.norm(state), 1.0):
mid_measurements[op] = 0
mid_measurements[op] = -1
return np.zeros_like(state)
wire = op.wires
probs = qml.devices.qubit.measure(qml.probs(wire), state)

try: # pragma: no cover
sample = np.random.binomial(1, probs[1])
except ValueError as e: # pragma: no cover
if probs[1] > 1: # MachEps error, safe to catch
sample = np.random.binomial(1, np.round(probs[1], 15))
else: # Other general error, continue to fail
raise e

sample = qml.devices.qubit.sampling.measure_with_samples(
[qml.sample(wires=wire)], state, Shots(1)
)
sample = int(sample[0])
mid_measurements[op] = sample
if op.postselect is not None and sample != op.postselect:
mid_measurements[op] = -1
return np.zeros_like(state)
axis = wire.toarray()[0]
slices = [slice(None)] * qml.math.ndim(state)
Expand Down
30 changes: 22 additions & 8 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to sample a state."""
from typing import List, Union, Tuple
from typing import List, Tuple, Union

import numpy as np

import pennylane as qml
from pennylane.ops import Sum, Hamiltonian, LinearCombination
from pennylane.measurements import (
SampleMeasurement,
Shots,
ExpectationMP,
ClassicalShadowMP,
ShadowExpvalMP,
CountsMP,
ExpectationMP,
SampleMeasurement,
ShadowExpvalMP,
Shots,
)
from pennylane.ops import Hamiltonian, LinearCombination, Sum
from pennylane.typing import TensorLike

from .apply_operation import apply_operation
from .measure import flatten_state

Expand Down Expand Up @@ -165,20 +167,21 @@ def _apply_diagonalizing_gates(

# pylint:disable = too-many-arguments
def measure_with_samples(
mps: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]],
measurements: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]],
state: np.ndarray,
shots: Shots,
is_state_batched: bool = False,
rng=None,
prng_key=None,
mid_measurements: dict = None,
) -> List[TensorLike]:
"""
Returns the samples of the measurement process performed on the given state.
This function assumes that the user-defined wire labels in the measurement process
have already been mapped to integer wires used in the device.
Args:
mp (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]):
measurements (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]):
The sample measurements to perform
state (np.ndarray[complex]): The state vector to sample from
shots (Shots): The number of samples to take
Expand All @@ -188,15 +191,22 @@ def measure_with_samples(
If no value is provided, a default RNG will be used.
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. Only for simulation using JAX.
mid_measurements (None, dict): Dictionary of mid-circuit measurements
Returns:
List[TensorLike[Any]]: Sample measurement results
"""
# last N measurements are sampling MCMs in ``dynamic_one_shot`` execution mode
mps = measurements[0 : -len(mid_measurements)] if mid_measurements else measurements
skip_measure = any(v == -1 for v in mid_measurements.values()) if mid_measurements else False

groups, indices = _group_measurements(mps)

all_res = []
for group in groups:
if skip_measure:
all_res.extend([None] * len(group))
continue
if isinstance(group[0], ExpectationMP) and isinstance(
group[0].obs, (Hamiltonian, LinearCombination)
):
Expand All @@ -223,6 +233,10 @@ def measure_with_samples(
res for _, res in sorted(list(enumerate(all_res)), key=lambda r: flat_indices[r[0]])
)

# append MCM samples
if mid_measurements:
sorted_res += tuple(mid_measurements.values())

# put the shot vector axis before the measurement axis
if shots.has_partitioned_shots:
sorted_res = tuple(zip(*sorted_res))
Expand Down
37 changes: 22 additions & 15 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,18 @@
# pylint: disable=protected-access
from typing import Optional

from numpy.random import default_rng
import numpy as np
from numpy.random import default_rng

import pennylane as qml
from pennylane.measurements import (
MidMeasureMP,
)
from pennylane.measurements import MidMeasureMP
from pennylane.typing import Result

from .initialize_state import create_initial_state
from .apply_operation import apply_operation
from .initialize_state import create_initial_state
from .measure import measure
from .sampling import jax_random_split, measure_with_samples


INTERFACE_TO_LIKE = {
# map interfaces known by autoray to themselves
None: None,
Expand Down Expand Up @@ -153,7 +150,10 @@ def get_final_state(circuit, debugger=None, interface=None, mid_measurements=Non
return state, is_state_batched


def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=None) -> Result:
# pylint: disable=too-many-arguments
def measure_final_state(
circuit, state, is_state_batched, rng=None, prng_key=None, mid_measurements: dict = None
) -> Result:
"""
Perform the measurements required by the circuit on the provided state.
Expand All @@ -170,15 +170,19 @@ def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=Non
the key to the JAX pseudo random number generator. Only for simulation using JAX.
If None, the default ``sample_state`` function and a ``numpy.random.default_rng``
will be for sampling.
mid_measurements (None, dict): Dictionary of mid-circuit measurements
Returns:
Tuple[TensorLike]: The measurement results
"""

circuit = circuit.map_to_standard_wires()

# analytic case

if not circuit.shots:
# analytic case
if mid_measurements is not None:
raise TypeError("Native mid-circuit measurements are only supported with finite shots.")

if len(circuit.measurements) == 1:
return measure(circuit.measurements[0], state, is_state_batched=is_state_batched)
Expand All @@ -197,6 +201,7 @@ def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=Non
is_state_batched=is_state_batched,
rng=rng,
prng_key=prng_key,
mid_measurements=mid_measurements,
)

if len(circuit.measurements) == 1:
Expand Down Expand Up @@ -283,13 +288,15 @@ def simulate_one_shot_native_mcm(
dict: The mid-circuit measurement results of the simulation
"""
_, key = jax_random_split(prng_key)
mcm_dict = {}
mid_measurements = {}
state, is_state_batched = get_final_state(
circuit, debugger=debugger, interface=interface, mid_measurements=mcm_dict
circuit, debugger=debugger, interface=interface, mid_measurements=mid_measurements
)
if not np.allclose(np.linalg.norm(state), 1.0):
return None, mcm_dict
return (
measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=key),
mcm_dict,
return measure_final_state(
circuit,
state,
is_state_batched,
rng=rng,
prng_key=key,
mid_measurements=mid_measurements,
)
Loading

0 comments on commit 44018e9

Please sign in to comment.