Skip to content

Commit

Permalink
Default postselect_mode=None, raise error with DM+jax_jit+hw_like, docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mudit2812 committed Jun 3, 2024
1 parent 85258d4 commit e0555e5
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 22 deletions.
7 changes: 4 additions & 3 deletions doc/introduction/measurements.rst
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,10 @@ PennyLane. For ease of use, we provide the following configuration options to us
-1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -2.1474836e+09,
-1.0000000e+00, -1.0000000e+00], dtype=float32, weak_type=True)

* If ``mcm_method="deferred"``, then using ``postselect_mode="hw-like"`` will have the same behaviour as when
``postselect_mode="fill-shots"``. This is due to the limitations of the :func:`~pennylane.defer_measurements`
transform, and this behaviour will change in the future to be more consistent with ``mcm_method="one-shot"``.
* When using ``jax.jit``, using ``mcm_method="deferred"`` is not supported with ``postselect_mode="hw-like"``.
Therefore, the default behaviour will be to use ``postselect_mode="fill-shots"``. This is due to limitations
of the :func:`~pennylane.defer_measurements` transform, and this behaviour will change in the future to be
more consistent with ``mcm_method="one-shot"``.

Changing the number of shots
----------------------------
Expand Down
2 changes: 1 addition & 1 deletion pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def batch_execute(self, circuits, **kwargs):
)

if self.capabilities().get("supports_mid_measure", False):
kwargs.setdefault("postselect_mode", "hw-like")
kwargs.setdefault("postselect_mode", None)

results = []
for circuit in circuits:
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ class MCMConfig:
for each shot separately. If not specified, the device will decide which method to
use."""

postselect_mode: str = "hw-like"
postselect_mode: Optional[str] = None
"""Configuration for handling shots with mid-circuit measurement postselection. If
``"hw-like"``, invalid shots will be discarded and only results for valid shots will
be returned. If ``"fill-shots"``, results corresponding to the original number of
shots will be returned."""
shots will be returned. If not specified, the device will decide which mode to use."""


# pylint: disable=too-many-instance-attributes
Expand Down
6 changes: 3 additions & 3 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def apply_operation(
interface (str): The machine learning interface of the state
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. ``"hw-like"`` by default.
keep the same number of shots. ``None`` by default.
rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. Only for simulation using JAX.
Expand Down Expand Up @@ -302,7 +302,7 @@ def apply_mid_measure(
mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the sampled value
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. ``"hw-like"`` by default.
keep the same number of shots. ``None`` by default.
rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. Only for simulation using JAX.
Expand All @@ -314,7 +314,7 @@ def apply_mid_measure(
mid_measurements = execution_kwargs.get("mid_measurements", None)
rng = execution_kwargs.get("rng", None)
prng_key = execution_kwargs.get("prng_key", None)
postselect_mode = execution_kwargs.get("postselect_mode", "hw-like")
postselect_mode = execution_kwargs.get("postselect_mode", None)

if is_state_batched:
raise ValueError("MidMeasureMP cannot be applied to batched states.")
Expand Down
11 changes: 7 additions & 4 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg

rng = execution_kwargs.get("rng", None)
prng_key = execution_kwargs.get("prng_key", None)
postselect_mode = execution_kwargs.get("postselect_mode", "hw-like")
postselect_mode = execution_kwargs.get("postselect_mode", None)

if postselect_mode == "hw-like" and qml.math.is_abstract(state):
raise ValueError("Using postselect_mode='hw-like' is not supported with jax-jit.")

# The floor function is being used here so that a norm very close to zero becomes exactly
# equal to zero so that the state can become invalid. This way, execution can continue, and
Expand Down Expand Up @@ -138,7 +141,7 @@ def get_final_state(circuit, debugger=None, **execution_kwargs):
If None, a ``numpy.random.default_rng`` will be used for sampling.
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``"hw-like"``.
keep the same number of shots. Default is ``None``.
Returns:
Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and
Expand Down Expand Up @@ -276,7 +279,7 @@ def simulate(
interface (str): The machine learning interface to create the initial state with
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``"hw-like"``.
keep the same number of shots. Default is ``None``.
Returns:
tuple(TensorLike): The results of the simulation
Expand Down Expand Up @@ -351,7 +354,7 @@ def simulate_one_shot_native_mcm(
interface (str): The machine learning interface to create the initial state with
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``"hw-like"``.
keep the same number of shots. Default is ``None``.
Returns:
tuple(TensorLike): The results of the simulation
Expand Down
2 changes: 2 additions & 0 deletions pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,8 @@ def cost_fn(params, x):
if interface == "jax-jit" and config.mcm_config.mcm_method == "deferred":
# This is a current limitation of defer_measurements. "hw-like" behaviour is
# not yet accessible.
if config.mcm_config.postselect_mode == "hw-like":
raise ValueError("Using postselect_mode='hw-like' is not supported with jax-jit.")
config.mcm_config.postselect_mode = "fill-shots"

if transform_program is None:
Expand Down
9 changes: 4 additions & 5 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ class QNode:
postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. If
``"hw-like"``, invalid shots will be discarded and only results for valid shots will be returned.
If ``"fill-shots"``, results corresponding to the original number of shots will be returned. The
default is ``"hw-like"``. For usage details, please refer to the
:doc:`main measurements page </introduction/measurements>`.
default is ``None``, in which case the device will automatically choose the best configuration. For
usage details, please refer to the :doc:`main measurements page </introduction/measurements>`.
mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements. Use ``"deferred"``
to apply the deferred measurements principle (using the :func:`~pennylane.defer_measurements` transform),
or ``"one-shot"`` if using finite shots to execute the circuit for each shot separately. If not provided,
Expand Down Expand Up @@ -452,7 +452,7 @@ def __init__(
cachesize=10000,
max_diff=1,
device_vjp=False,
postselect_mode="hw-like",
postselect_mode=None,
mcm_method=None,
**gradient_kwargs,
):
Expand Down Expand Up @@ -520,10 +520,9 @@ def __init__(
self.max_expansion = max_expansion
cache = (max_diff > 1) if cache == "auto" else cache

postselect_mode = postselect_mode or "hw-like"
if mcm_method not in ("deferred", "one-shot", None):
raise ValueError(f"Invalid mid-circuit measurements method '{mcm_method}'.")
if postselect_mode not in ("hw-like", "fill-shots"):
if postselect_mode not in ("hw-like", "fill-shots", None):
raise ValueError(f"Invalid postselection mode '{postselect_mode}'.")

# execution keyword arguments
Expand Down
38 changes: 34 additions & 4 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,9 +1752,9 @@ def f(x):

@pytest.mark.jax
@pytest.mark.parametrize("diff_method", [None, "best"])
def test_defer_measurements_hw_like_with_jit(self, diff_method, mocker):
"""Test that using mcm_method="deferred" with postselect_mode="hw-like" defaults
to behaviour like postselect_mode="fill-shots" when using jax jit."""
def test_defer_measurements_with_jit(self, diff_method, mocker):
"""Test that using mcm_method="deferred" defaults to behaviour like
postselect_mode="fill-shots" when using jax jit."""
import jax # pylint: disable=import-outside-toplevel

shots = 100
Expand All @@ -1765,7 +1765,7 @@ def test_defer_measurements_hw_like_with_jit(self, diff_method, mocker):

dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123))

@qml.qnode(dev, diff_method=diff_method, postselect_mode="hw-like", mcm_method="deferred")
@qml.qnode(dev, diff_method=diff_method, mcm_method="deferred")
def f(x):
qml.RX(x, 0)
qml.measure(0, postselect=postselect)
Expand All @@ -1783,6 +1783,36 @@ def f(x):
assert qml.math.allclose(res, postselect)
assert qml.math.allclose(res_jit, postselect)

@pytest.mark.jax
# @pytest.mark.parametrize("diff_method", [None, "best"])
@pytest.mark.parametrize("diff_method", ["best"])
def test_hw_like_error_with_jit(self, diff_method):
"""Test that an error is raised if attempting to use postselect_mode="hw-like"
with jax jit."""
import jax # pylint: disable=import-outside-toplevel

shots = 100
postselect = 1
param = jax.numpy.array(np.pi / 2)

dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123))

@qml.qnode(dev, diff_method=diff_method, mcm_method="deferred", postselect_mode="hw-like")
def f(x):
qml.RX(x, 0)
qml.measure(0, postselect=postselect)
return qml.sample(wires=0)

f_jit = jax.jit(f)

# Checking that an error is not raised without jit
_ = f(param)

with pytest.raises(
ValueError, match="Using postselect_mode='hw-like' is not supported with jax-jit."
):
_ = f_jit(param)


class TestTapeExpansion:
"""Test that tape expansion within the QNode works correctly"""
Expand Down

0 comments on commit e0555e5

Please sign in to comment.