From 60d2b5ad3f3af5167bddf41a162336e9b3b536fb Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Wed, 29 May 2024 21:13:58 -0400 Subject: [PATCH] Capture measurements with `jax.make_jaxpr`. (#5564) **Context:** **Description of the Change:** **Benefits:** We can capture measurements and keep track of their resulting shapes in an extensible manner. **Possible Drawbacks:** Measurements do need a lot more hand-holding then observables due to the "duel mode" inputs of either wires or observables, and the need to fully specifying the resulting shape when we perform an actual measurment. With the current framework, it will theoretically be extensible to add new measurements, but the process is slightly more error prone. We also now provide duplicate information from `MeasurementProcess.numeric_type`, `MeasurementProcess.shape`, and `MeasurementProcess._abstract_eval`. But the instance-based property and method didn't play well with the need to track that information during an abstract evaluation. **Related GitHub Issues:** [sc-61200] [sc-59452] --------- Co-authored-by: dwierichs Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com> --- doc/releases/changelog-dev.md | 4 +- pennylane/capture/__init__.py | 27 +- pennylane/capture/capture_meta.py | 33 + pennylane/capture/primitives.py | 208 ++++++- pennylane/measurements/classical_shadow.py | 30 + pennylane/measurements/counts.py | 12 + pennylane/measurements/expval.py | 4 +- pennylane/measurements/measurements.py | 94 ++- pennylane/measurements/mid_measure.py | 27 +- pennylane/measurements/mutual_info.py | 21 +- pennylane/measurements/probs.py | 18 +- pennylane/measurements/sample.py | 33 +- pennylane/measurements/state.py | 28 +- pennylane/measurements/var.py | 4 +- tests/capture/test_capture_module.py | 26 + tests/capture/test_measurements_capture.py | 661 +++++++++++++++++++++ 16 files changed, 1196 insertions(+), 34 deletions(-) create mode 100644 tests/capture/test_capture_module.py create mode 100644 tests/capture/test_measurements_capture.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 0448d074a7f..0e4fc15314c 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -93,8 +93,8 @@ `m = measure(0); qml.sample(m)`. [(#5673)](https://github.com/PennyLaneAI/pennylane/pull/5673) -* PennyLane operators can now automatically be captured as instructions in JAXPR. See the experimental - `capture` module for more information. +* PennyLane operators and measurements can now automatically be captured as instructions in JAXPR. + [(#5564)](https://github.com/PennyLaneAI/pennylane/pull/5564) [(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511) * The `decompose` transform has an `error` kwarg to specify the type of error that should be raised, diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index a3a35dd427c..6cd2dacf49d 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -21,6 +21,18 @@ This module is experimental and will change significantly in the future. +.. currentmodule:: pennylane.capture + +.. autosummary:: + :toctree: api + + ~disable + ~enable + ~enabled + ~create_operator_primitive + ~create_measurement_obs_primitive + ~create_measurement_wires_primitive + ~create_measurement_mcm_primitive To activate and deactivate the new PennyLane program capturing mechanism, use the switches ``qml.capture.enable`` and ``qml.capture.disable``. @@ -114,4 +126,17 @@ def _(*args, **kwargs): """ from .switches import disable, enable, enabled from .capture_meta import CaptureMeta -from .primitives import create_operator_primitive +from .primitives import ( + create_operator_primitive, + create_measurement_obs_primitive, + create_measurement_wires_primitive, + create_measurement_mcm_primitive, +) + + +def __getattr__(key): + if key == "AbstractOperator": + from .primitives import _get_abstract_operator # pylint: disable=import-outside-toplevel + + return _get_abstract_operator() + raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'") diff --git a/pennylane/capture/capture_meta.py b/pennylane/capture/capture_meta.py index dcd4d78efa1..fe82b4926f6 100644 --- a/pennylane/capture/capture_meta.py +++ b/pennylane/capture/capture_meta.py @@ -28,6 +28,39 @@ class CaptureMeta(type): See ``pennylane/capture/explanations.md`` for more detailed information on how this technically works. + + .. code-block:: + + class AbstractMyObj(jax.core.AbstractValue): + pass + + jax.core.raise_to_shaped_mappings[AbstractMyObj] = lambda aval, _: aval + + class MyObj(metaclass=qml.capture.CaptureMeta): + + primitive = jax.core.Primitive("MyObj") + + @classmethod + def _primitive_bind_call(cls, a): + return cls.primitive.bind(a) + + def __init__(self, a): + self.a = a + + @MyObj.primitive.def_impl + def _(a): + return type.__call__(MyObj, a) + + @MyObj.primitive.def_abstract_eval + def _(a): + return AbstractMyObj() + + >>> jaxpr = jax.make_jaxpr(MyObj)(0.1) + >>> jaxpr + { lambda ; a:f32[]. let b:AbstractMyObj() = MyObj a in (b,) } + >>> jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.1) + [<__main__.MyObj at 0x17fc3ea50>] + """ @property diff --git a/pennylane/capture/primitives.py b/pennylane/capture/primitives.py index 6c568288f57..c8affc0a445 100644 --- a/pennylane/capture/primitives.py +++ b/pennylane/capture/primitives.py @@ -16,7 +16,7 @@ """ from functools import lru_cache -from typing import Optional +from typing import Callable, Optional, Tuple, Type import pennylane as qml @@ -82,9 +82,94 @@ def _pow(a, b): return AbstractOperator -def create_operator_primitive(operator_type: type) -> Optional["jax.core.Primitive"]: +@lru_cache +def _get_abstract_measurement(): + if not has_jax: # pragma: no cover + raise ImportError("Jax is required for plxpr.") # pragma: no cover + + class AbstractMeasurement(jax.core.AbstractValue): + """An abstract measurement. + + Args: + abstract_eval (Callable): See :meth:`~.MeasurementProcess._abstract_eval`. A function of + ``n_wires``, ``has_eigvals``, ``num_device_wires`` and ``shots`` that returns a shape + and numeric type. + n_wires=None (Optional[int]): the number of wires + has_eigvals=False (bool): Whether or not the measurement contains eigenvalues in a wires+eigvals + diagonal representation. + + """ + + def __init__( + self, abstract_eval: Callable, n_wires: Optional[int] = None, has_eigvals: bool = False + ): + self._abstract_eval = abstract_eval + self._n_wires = n_wires + self.has_eigvals: bool = has_eigvals + + def abstract_eval(self, num_device_wires: int, shots: int) -> Tuple[Tuple, type]: + """Calculate the shape and dtype for an evaluation with specified number of device + wires and shots. + + """ + return self._abstract_eval( + n_wires=self._n_wires, + has_eigvals=self.has_eigvals, + num_device_wires=num_device_wires, + shots=shots, + ) + + @property + def n_wires(self) -> Optional[int]: + """The number of wires for a wire based measurement. + + Options are: + * ``None``: The measurement is observable based or single mcm based + * ``0``: The measurement is broadcasted across all available devices wires + * ``int>0``: A wire or mcm based measurement with specified wires or mid circuit measurements. + + """ + return self._n_wires + + def __repr__(self): + if self.has_eigvals: + return f"AbstractMeasurement(n_wires={self.n_wires}, has_eigvals=True)" + return f"AbstractMeasurement(n_wires={self.n_wires})" + + # pylint: disable=missing-function-docstring + def at_least_vspace(self): + # TODO: investigate the proper definition of this method + raise NotImplementedError + + # pylint: disable=missing-function-docstring + def join(self, other): + # TODO: investigate the proper definition of this method + raise NotImplementedError + + # pylint: disable=missing-function-docstring + def update(self, **kwargs): + # TODO: investigate the proper definition of this method + raise NotImplementedError + + def __eq__(self, other): + return isinstance(other, AbstractMeasurement) + + def __hash__(self): + return hash("AbstractMeasurement") + + jax.core.raise_to_shaped_mappings[AbstractMeasurement] = lambda aval, _: aval + + return AbstractMeasurement + + +def create_operator_primitive( + operator_type: Type["qml.operation.Operator"], +) -> Optional["jax.core.Primitive"]: """Create a primitive corresponding to an operator type. + Called when defining any :class:`~.Operator` subclass, and is used to set the + ``Operator._primitive`` class property. + Args: operator_type (type): a subclass of qml.operation.Operator @@ -117,3 +202,122 @@ def _(*_, **__): return abstract_type() return primitive + + +def create_measurement_obs_primitive( + measurement_type: Type["qml.measurements.MeasurementProcess"], name: str +) -> Optional["jax.core.Primitive"]: + """Create a primitive corresponding to the input type where the abstract inputs are an operator. + + Called by default when defining any class inheriting from :class:`~.MeasurementProcess`, and is used to + set the ``MeasurementProcesss._obs_primitive`` property. + + Args: + measurement_type (type): a subclass of :class:`~.MeasurementProcess` + name (str): the preferred string name for the class. For example, ``"expval"``. + ``"_obs"`` is appended to this name for the name of the primitive. + + Returns: + Optional[jax.core.Primitive]: A new jax primitive. ``None`` is returned if jax is not available. + + """ + if not has_jax: + return None + + primitive = jax.core.Primitive(name + "_obs") + + @primitive.def_impl + def _(obs, **kwargs): + return type.__call__(measurement_type, obs=obs, **kwargs) + + abstract_type = _get_abstract_measurement() + + @primitive.def_abstract_eval + def _(*_, **__): + abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access + return abstract_type(abstract_eval, n_wires=None) + + return primitive + + +def create_measurement_mcm_primitive( + measurement_type: Type["qml.measurements.MeasurementProcess"], name: str +) -> Optional["jax.core.Primitive"]: + """Create a primitive corresponding to the input type where the abstract inputs are classical + mid circuit measurement results. + + Called by default when defining any class inheriting from :class:`~.MeasurementProcess`, and is used to + set the ``MeasurementProcesss._mcm_primitive`` property. + + Args: + measurement_type (type): a subclass of :class:`~.MeasurementProcess` + name (str): the preferred string name for the class. For example, ``"expval"``. + ``"_mcm"`` is appended to this name for the name of the primitive. + + Returns: + Optional[jax.core.Primitive]: A new jax primitive. ``None`` is returned if jax is not available. + """ + + if not has_jax: + return None + + primitive = jax.core.Primitive(name + "_mcm") + + @primitive.def_impl + def _(*mcms, **kwargs): + raise NotImplementedError( + "mcm measurements do not currently have a concrete implementation" + ) + # need to figure out how to convert a jaxpr int into a measurement value, and pass + # that measurment value here. + # return type.__call__(measurement_type, obs=mcms, **kwargs) + + abstract_type = _get_abstract_measurement() + + @primitive.def_abstract_eval + def _(*mcms, **__): + abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access + return abstract_type(abstract_eval, n_wires=len(mcms)) + + return primitive + + +def create_measurement_wires_primitive( + measurement_type: type, name: str +) -> Optional["jax.core.Primitive"]: + """Create a primitive corresponding to the input type where the abstract inputs are the wires. + + Called by default when defining any class inheriting from :class:`~.MeasurementProcess`, and is used to + set the ``MeasurementProcesss._wires_primitive`` property. + + Args: + measurement_type (type): a subclass of :class:`~.MeasurementProcess` + name (str): the preferred string name for the class. For example, ``"expval"``. + ``"_wires"`` is appended to this name for the name of the primitive. + + Returns: + Optional[jax.core.Primitive]: A new jax primitive. ``None`` is returned if jax is not available. + """ + if not has_jax: + return None + + primitive = jax.core.Primitive(name + "_wires") + + @primitive.def_impl + def _(*args, has_eigvals=False, **kwargs): + if has_eigvals: + wires = qml.wires.Wires(args[:-1]) + kwargs["eigvals"] = args[-1] + else: + wires = qml.wires.Wires(args) + return type.__call__(measurement_type, wires=wires, **kwargs) + + abstract_type = _get_abstract_measurement() + + @primitive.def_abstract_eval + def _(*args, has_eigvals=False, **_): + abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access + n_wires = len(args) - 1 if has_eigvals else len(args) + return abstract_type(abstract_eval, n_wires=n_wires, has_eigvals=has_eigvals) + + return primitive diff --git a/pennylane/measurements/classical_shadow.py b/pennylane/measurements/classical_shadow.py index 5335c6b3379..8f9a841a66f 100644 --- a/pennylane/measurements/classical_shadow.py +++ b/pennylane/measurements/classical_shadow.py @@ -440,6 +440,16 @@ def numeric_type(self): def return_type(self): return Shadow + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ) -> tuple: + return (2, shots, n_wires), np.int8 + def shape(self, device, shots): # pylint: disable=unused-argument # otherwise, the return type requires a device if not shots: @@ -496,6 +506,19 @@ def __init__( self.k = k super().__init__(id=id) + # pylint: disable=arguments-differ + @classmethod + def _primitive_bind_call( + cls, + H: Union[Operator, Sequence], + seed: Optional[int] = None, + k: int = 1, + **kwargs, + ): + if cls._obs_primitive is None: # pragma: no cover + return type.__call__(cls, H=H, seed=seed, k=k, **kwargs) # pragma: no cover + return cls._obs_primitive.bind(H, seed=seed, k=k, **kwargs) + def process(self, tape, device): bits, recipes = qml.classical_shadow(wires=self.wires, seed=self.seed).process(tape, device) shadow = qml.shadows.ClassicalShadow(bits, recipes, wire_map=self.wires.tolist()) @@ -573,3 +596,10 @@ def __copy__(self): k=self.k, seed=self.seed, ) + + +if ShadowExpvalMP._obs_primitive is not None: # pylint: disable=protected-access + + @ShadowExpvalMP._obs_primitive.def_impl # pylint: disable=protected-access + def _(H, **kwargs): + return type.__call__(ShadowExpvalMP, H, **kwargs) diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py index 903052edae1..807be11d2e8 100644 --- a/pennylane/measurements/counts.py +++ b/pennylane/measurements/counts.py @@ -205,6 +205,18 @@ def __repr__(self): return f"CountsMP(wires={self.wires.tolist()}, all_outcomes={self.all_outcomes})" + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ) -> tuple: + raise NotImplementedError( + "CountsMP returns a dictionary, which is not compatible with capture." + ) + @property def hash(self): """int: returns an integer hash uniquely representing the measurement process""" diff --git a/pennylane/measurements/expval.py b/pennylane/measurements/expval.py index 330be5cf434..f87f7b755dc 100644 --- a/pennylane/measurements/expval.py +++ b/pennylane/measurements/expval.py @@ -91,9 +91,7 @@ class ExpectationMP(SampleMeasurement, StateMeasurement): where the instance has to be identified """ - @property - def return_type(self): - return Expectation + return_type = Expectation @property def numeric_type(self): diff --git a/pennylane/measurements/measurements.py b/pennylane/measurements/measurements.py index 8e8ac4f34eb..3397d95f03c 100644 --- a/pennylane/measurements/measurements.py +++ b/pennylane/measurements/measurements.py @@ -18,7 +18,7 @@ """ import copy import functools -from abc import ABC, abstractmethod +from abc import ABC, ABCMeta, abstractmethod from enum import Enum from typing import Optional, Sequence, Tuple, Union @@ -112,7 +112,12 @@ class MeasurementShapeError(ValueError): quantum tape.""" -class MeasurementProcess(ABC): +# pylint: disable=abstract-method +class ABCCaptureMeta(qml.capture.CaptureMeta, ABCMeta): + """A combination of the capture meta and ABCMeta""" + + +class MeasurementProcess(ABC, metaclass=ABCCaptureMeta): """Represents a measurement process occurring at the end of a quantum variational circuit. @@ -130,8 +135,84 @@ class MeasurementProcess(ABC): # pylint:disable=too-many-instance-attributes + _obs_primitive: Optional["jax.core.Primitive"] = None + _wires_primitive: Optional["jax.core.Primitive"] = None + _mcm_primitive: Optional["jax.core.Primitive"] = None + def __init_subclass__(cls, **_): register_pytree(cls, cls._flatten, cls._unflatten) + name = getattr(cls.return_type, "value", cls.__name__) + cls._wires_primitive = qml.capture.create_measurement_wires_primitive(cls, name=name) + cls._obs_primitive = qml.capture.create_measurement_obs_primitive(cls, name=name) + cls._mcm_primitive = qml.capture.create_measurement_mcm_primitive(cls, name=name) + + @classmethod + def _primitive_bind_call(cls, obs=None, wires=None, eigvals=None, id=None, **kwargs): + """Called instead of ``type.__call__`` if ``qml.capture.enabled()``. + + Measurements have three "modes": + + 1) Wires or wires + eigvals + 2) Observable + 3) Mid circuit measurements + + Not all measurements support all three modes. For example, ``VNEntropyMP`` does not + allow being specified via an observable. But we handle the generic case here. + + """ + if cls._obs_primitive is None: + # safety check if primitives aren't set correctly. + return type.__call__(cls, obs=obs, wires=wires, eigvals=eigvals, id=id, **kwargs) + if obs is None: + wires = () if wires is None else wires + if eigvals is None: + return cls._wires_primitive.bind(*wires, **kwargs) # wires + return cls._wires_primitive.bind( + *wires, eigvals, has_eigvals=True, **kwargs + ) # wires + eigvals + + if isinstance(obs, Operator) or isinstance( + getattr(obs, "aval", None), qml.capture.AbstractOperator + ): + return cls._obs_primitive.bind(obs, **kwargs) + if isinstance(obs, (list, tuple)): + return cls._mcm_primitive.bind(*obs, **kwargs) # iterable of mcms + return cls._mcm_primitive.bind(obs, **kwargs) # single mcm + + # pylint: disable=unused-argument + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ) -> tuple[tuple, type]: + """Calculate the shape and dtype that will be returned when a measurement is performed. + + This information is similar to ``numeric_type`` and ``shape``, but is provided through + a class method and does not require the creation of an instance. + + Note that ``shots`` should strictly be ``None`` or ``int``. Shot vectors are handled higher + in the stack. + + If ``n_wires is None``, then the measurement process contains an observable. An integer + ``n_wires`` can correspond either to the number of wires or to the number of mid circuit + measurements. ``n_wires = 0`` indicates a measurement that is broadcasted across all device wires. + + >>> ProbabilityMP._abstract_eval(n_wires=2) + ((4,), float) + >>> ProbabilityMP._abstract_eval(n_wires=0, num_device_wires=2) + ((4,), float) + >>> SampleMP._abstract_eval(n_wires=0, shots=50, num_device_wires=2) + ((50, 2), int) + >>> SampleMP._abstract_eval(n_wires=4, has_eigvals=True, shots=50) + ((50,), float) + >>> SampleMP._abstract_eval(n_wires=None, shots=50) + ((50,), float) + + """ + return (), float def _flatten(self): metadata = (("wires", self.raw_wires),) @@ -173,7 +254,7 @@ def __init__( self.id = id if wires is not None: - if len(wires) == 0: + if not qml.capture.enabled() and len(wires) == 0: raise ValueError("Cannot set an empty list of wires.") if obs is not None: raise ValueError("Cannot set the wires if an observable is provided.") @@ -282,12 +363,13 @@ def __hash__(self): def __repr__(self): """Representation of this class.""" + name_str = self.return_type.value if self.return_type else type(self).__name__ if self.mv: - return f"{self.return_type.value}({repr(self.mv)})" + return f"{name_str}({repr(self.mv)})" if self.obs: - return f"{self.return_type.value}({self.obs})" + return f"{name_str}({self.obs})" if self._eigvals is not None: - return f"{self.return_type.value}(eigvals={self._eigvals}, wires={self.wires.tolist()})" + return f"{name_str}(eigvals={self._eigvals}, wires={self.wires.tolist()})" # Todo: when tape is core the return type will always be taken from the MeasurementProcess return f"{getattr(self.return_type, 'value', 'None')}(wires={self.wires.tolist()})" diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index ea5a19196b3..1c927813395 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -15,7 +15,7 @@ This module contains the qml.measure measurement. """ import uuid -from typing import Generic, Optional, TypeVar +from typing import Generic, Hashable, Optional, TypeVar, Union import pennylane as qml from pennylane.wires import Wires @@ -23,7 +23,9 @@ from .measurements import MeasurementProcess, MidMeasure -def measure(wires: Wires, reset: Optional[bool] = False, postselect: Optional[int] = None): +def measure( + wires: Union[Hashable, Wires], reset: Optional[bool] = False, postselect: Optional[int] = None +): r"""Perform a mid-circuit measurement in the computational basis on the supplied qubit. @@ -207,7 +209,6 @@ def func(x): samples, leading to unexpected or incorrect results. """ - wire = Wires(wires) if len(wire) > 1: raise qml.QuantumFunctionError( @@ -217,6 +218,10 @@ def func(x): # Create a UUID and a map between MP and MV to support serialization measurement_id = str(uuid.uuid4())[:8] mp = MidMeasureMP(wires=wire, reset=reset, postselect=postselect, id=measurement_id) + if qml.capture.enabled(): + raise NotImplementedError( + "Capture cannot currently handle classical output from mid circuit measurements." + ) return MeasurementValue([mp], processing_fn=lambda v: v) @@ -257,6 +262,22 @@ def __init__( self.reset = reset self.postselect = postselect + # pylint: disable=arguments-renamed, arguments-differ + @classmethod + def _primitive_bind_call(cls, wires=None, reset=False, postselect=None, id=None): + wires = () if wires is None else wires + return cls._wires_primitive.bind(*wires, reset=reset, postselect=postselect) + + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ) -> tuple: + return (), int + def label(self, decimals=None, base_label=None, cache=None): # pylint: disable=unused-argument r"""How the mid-circuit measurement is represented in diagrams and drawings. diff --git a/pennylane/measurements/mutual_info.py b/pennylane/measurements/mutual_info.py index 88f7ed3183d..d52f282ab55 100644 --- a/pennylane/measurements/mutual_info.py +++ b/pennylane/measurements/mutual_info.py @@ -79,7 +79,9 @@ def circuit_mutual(x): wires1 = qml.wires.Wires(wires1) # the subsystems cannot overlap - if [wire for wire in wires0 if wire in wires1]: + if not any(qml.math.is_abstract(w) for w in wires0 + wires1) and [ + wire for wire in wires0 if wire in wires1 + ]: raise qml.QuantumFunctionError( "Subsystems for computing mutual information must not overlap." ) @@ -113,6 +115,14 @@ def __init__( self.log_base = log_base super().__init__(wires=wires, id=id) + # pylint: disable=arguments-differ + @classmethod + def _primitive_bind_call(cls, wires: Sequence, **kwargs): + if cls._wires_primitive is None: # pragma: no cover + # just a safety check + return type.__call__(cls, wires=wires, **kwargs) # pragma: no cover + return cls._wires_primitive.bind(*wires[0], *wires[1], n_wires0=len(wires[0]), **kwargs) + def __repr__(self): return f"MutualInfo(wires0={self.raw_wires[0].tolist()}, wires1={self.raw_wires[1].tolist()}, log_base={self.log_base})" @@ -158,3 +168,12 @@ def process_state(self, state: Sequence[complex], wire_order: Wires): c_dtype=state.dtype, base=self.log_base, ) + + +if MutualInfoMP._wires_primitive is not None: + + @MutualInfoMP._wires_primitive.def_impl + def _(*all_wires, n_wires0, **kwargs): + wires0 = all_wires[:n_wires0] + wires1 = all_wires[n_wires0:] + return type.__call__(MutualInfoMP, wires=(wires0, wires1), **kwargs) diff --git a/pennylane/measurements/probs.py b/pennylane/measurements/probs.py index a0ceb468f80..c1c0dc45c6e 100644 --- a/pennylane/measurements/probs.py +++ b/pennylane/measurements/probs.py @@ -44,7 +44,7 @@ def probs(wires=None, op=None) -> "ProbabilityMP": Args: wires (Sequence[int] or int): the wire the operation acts on - op (Observable or MeasurementValue): Observable (with a ``diagonalizing_gates`` + op (Observable or MeasurementValue or Sequence[MeasurementValue]): Observable (with a ``diagonalizing_gates`` attribute) that rotates the computational basis, or a ``MeasurementValue`` corresponding to mid-circuit measurements. @@ -103,7 +103,9 @@ def circuit(): return ProbabilityMP(obs=op) if isinstance(op, Sequence): - if not all(isinstance(o, MeasurementValue) and len(o.measurements) == 1 for o in op): + if not qml.math.is_abstract(op[0]) and not all( + isinstance(o, MeasurementValue) and len(o.measurements) == 1 for o in op + ): raise qml.QuantumFunctionError( "Only sequences of single MeasurementValues can be passed with the op argument. " "MeasurementValues manipulated using arithmetic operators cannot be used when " @@ -115,7 +117,7 @@ def circuit(): if isinstance(op, (qml.ops.Hamiltonian, qml.ops.LinearCombination)): raise qml.QuantumFunctionError("Hamiltonians are not supported for rotating probabilities.") - if op is not None and not op.has_diagonalizing_gates: + if op is not None and not qml.math.is_abstract(op) and not op.has_diagonalizing_gates: raise qml.QuantumFunctionError( f"{op} does not define diagonalizing gates : cannot be used to rotate the probability" ) @@ -147,9 +149,13 @@ class ProbabilityMP(SampleMeasurement, StateMeasurement): where the instance has to be identified """ - @property - def return_type(self): - return Probability + return_type = Probability + + @classmethod + def _abstract_eval(cls, n_wires=None, has_eigvals=False, shots=None, num_device_wires=0): + n_wires = num_device_wires if n_wires == 0 else n_wires + shape = (2**n_wires,) + return shape, float @property def numeric_type(self): diff --git a/pennylane/measurements/sample.py b/pennylane/measurements/sample.py index d36a78e4ec8..da10c55e61c 100644 --- a/pennylane/measurements/sample.py +++ b/pennylane/measurements/sample.py @@ -28,7 +28,7 @@ def sample( - op: Optional[Union[Operator, MeasurementValue]] = None, + op: Optional[Union[Operator, MeasurementValue, Sequence[MeasurementValue]]] = None, wires=None, ) -> "SampleMP": r"""Sample from the supplied observable, with the number of shots @@ -182,9 +182,34 @@ def __init__(self, obs=None, wires=None, eigvals=None, id=None): super().__init__(obs=obs, wires=wires, eigvals=eigvals, id=id) - @property - def return_type(self): - return Sample + return_type = Sample + + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ): + if shots is None: + raise ValueError("finite shots are required to use SampleMP") + sample_eigvals = n_wires is None or has_eigvals + dtype = float if sample_eigvals else int + + if n_wires == 0: + dim = num_device_wires + elif sample_eigvals: + dim = 1 + else: + dim = n_wires + + shape = [] + if shots != 1: + shape.append(shots) + if dim != 1: + shape.append(dim) + return tuple(shape), dtype @property @functools.lru_cache() diff --git a/pennylane/measurements/state.py b/pennylane/measurements/state.py index fdf37402e26..d97fcf8a971 100644 --- a/pennylane/measurements/state.py +++ b/pennylane/measurements/state.py @@ -139,9 +139,19 @@ class StateMP(StateMeasurement): def __init__(self, wires: Optional[Wires] = None, id: Optional[str] = None): super().__init__(wires=wires, id=id) - @property - def return_type(self): - return State + return_type = State + + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ): + n_wires = n_wires or num_device_wires + shape = (2**n_wires,) + return shape, complex @property def numeric_type(self): @@ -196,6 +206,18 @@ class DensityMatrixMP(StateMP): def __init__(self, wires: Wires, id: Optional[str] = None): super().__init__(wires=wires, id=id) + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ): + n_wires = n_wires or num_device_wires + shape = (2**n_wires, 2**n_wires) + return shape, complex + def shape(self, device, shots): num_shot_elements = ( sum(s.copies for s in shots.shot_vector) if shots.has_partitioned_shots else 1 diff --git a/pennylane/measurements/var.py b/pennylane/measurements/var.py index 32235ea56d7..613588ea572 100644 --- a/pennylane/measurements/var.py +++ b/pennylane/measurements/var.py @@ -83,9 +83,7 @@ class VarianceMP(SampleMeasurement, StateMeasurement): where the instance has to be identified """ - @property - def return_type(self): - return Variance + return_type = Variance @property def numeric_type(self): diff --git a/tests/capture/test_capture_module.py b/tests/capture/test_capture_module.py new file mode 100644 index 00000000000..c403fa04baa --- /dev/null +++ b/tests/capture/test_capture_module.py @@ -0,0 +1,26 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests capture module imports and access. +""" +import pytest + +import pennylane as qml + + +def test_no_attribute_available(): + """Test that if we try and access an attribute that doesn't exist, we get an attribute error.""" + + with pytest.raises(AttributeError): + _ = qml.capture.something diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py new file mode 100644 index 00000000000..8f877bca3ab --- /dev/null +++ b/tests/capture/test_measurements_capture.py @@ -0,0 +1,661 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for capturing measurements. +""" +import numpy as np + +# pylint: disable=protected-access +import pytest + +import pennylane as qml +from pennylane.capture.primitives import _get_abstract_measurement +from pennylane.measurements import ( + ClassicalShadowMP, + DensityMatrixMP, + ExpectationMP, + MidMeasureMP, + MutualInfoMP, + ProbabilityMP, + PurityMP, + SampleMP, + ShadowExpvalMP, + StateMP, + VarianceMP, + VnEntropyMP, +) + +jax = pytest.importorskip("jax") + +pytestmark = pytest.mark.jax + +AbstractMeasurement = _get_abstract_measurement() + + +@pytest.fixture(autouse=True) +def enable_disable_plxpr(): + qml.capture.enable() + yield + qml.capture.disable() + + +def _get_shapes_for(*measurements, shots=qml.measurements.Shots(None), num_device_wires=0): + if jax.config.jax_enable_x64: + dtype_map = { + float: jax.numpy.float64, + int: jax.numpy.int64, + complex: jax.numpy.complex128, + } + else: + dtype_map = { + float: jax.numpy.float32, + int: jax.numpy.int32, + complex: jax.numpy.complex64, + } + + shapes = [] + if not shots: + shots = [None] + + for s in shots: + for m in measurements: + shape, dtype = m.abstract_eval(shots=s, num_device_wires=num_device_wires) + shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype))) + return shapes + + +def test_abstract_measurement(): + """Tests for the AbstractMeasurement class.""" + am = AbstractMeasurement(ExpectationMP._abstract_eval, n_wires=2, has_eigvals=True) + + assert am.n_wires == 2 + assert am.has_eigvals is True + + expected_repr = "AbstractMeasurement(n_wires=2, has_eigvals=True)" + assert repr(am) == expected_repr + + assert am.abstract_eval(2, 50) == ((), float) + + with pytest.raises(NotImplementedError): + am.at_least_vspace() + + with pytest.raises(NotImplementedError): + am.join(am) + + with pytest.raises(NotImplementedError): + am.update(key="value") + + am2 = AbstractMeasurement(ExpectationMP._abstract_eval) + expected_repr2 = "AbstractMeasurement(n_wires=None)" + assert repr(am2) == expected_repr2 + + assert am == am2 + assert hash(am) == hash("AbstractMeasurement") + + +def test_counts_no_measure(): + """Test that counts can't be measured and raises a NotImplementedError.""" + + with pytest.raises(NotImplementedError, match=r"CountsMP returns a dictionary"): + qml.counts()._abstract_eval() + + +def test_mid_measure_not_implemented(): + """Test that measure raises a NotImplementedError if capture is enabled.""" + with pytest.raises(NotImplementedError): + qml.measure(0) + + +def test_primitive_none_behavior(): + """Test that if the obs primitive is None, the measurement can still + be created, but it just won't be captured into jaxpr. + """ + + # pylint: disable=too-few-public-methods + class MyMeasurement(qml.measurements.MeasurementProcess): + pass + + MyMeasurement._obs_primitive = None + + def f(): + return MyMeasurement(wires=qml.wires.Wires((0, 1))) + + mp = f() + assert isinstance(mp, MyMeasurement) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 0 + + +# pylint: disable=unnecessary-lambda +creation_funcs = [ + lambda: qml.state(), + lambda: qml.density_matrix(wires=(0, 1)), + lambda: qml.expval(qml.X(0)), + lambda: ExpectationMP(wires=qml.wires.Wires((0, 1)), eigvals=np.array([-1.0, -0.5, 0.5, 1.0])), + # lambda : qml.expval(qml.measure(0)+qml.measure(1)), + lambda: qml.var(qml.X(0)), + lambda: VarianceMP(wires=qml.wires.Wires((0, 1)), eigvals=np.array([-1.0, -0.5, 0.5, 1.0])), + # lambda : qml.var(qml.measure(0)+qml.measure(1)), + lambda: qml.probs(wires=(0, 1)), + lambda: qml.probs(op=qml.X(0)), + # lambda : qml.probs(op=[qml.measure(0), qml.measure(1)]), + lambda: ProbabilityMP(wires=qml.wires.Wires((0, 1)), eigvals=np.array([-1.0, -0.5, 0.5, 1.0])), + lambda: qml.sample(wires=(3, 4)), + lambda: qml.shadow_expval(np.array(2) * qml.X(0)), + lambda: qml.vn_entropy(wires=(1, 2)), + lambda: qml.purity(wires=(0, 1)), + lambda: qml.mutual_info(wires0=(1, 3), wires1=(2, 4), log_base=2), + lambda: qml.classical_shadow(wires=(0, 1), seed=84), + lambda: MidMeasureMP(qml.wires.Wires((0, 1))), +] + + +@pytest.mark.parametrize("func", creation_funcs) +def test_capture_and_eval(func): + """Test that captured jaxpr can be evaluated to restore the initial measurement.""" + + mp = func() + + jaxpr = jax.make_jaxpr(func)() + out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)[0] + + assert qml.equal(mp, out) + + +@pytest.mark.parametrize("x64_mode", [True, False]) +def test_mid_measure(x64_mode): + """Test that mid circuit measurements can be captured and executed.x""" + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(w): + return MidMeasureMP(qml.wires.Wires((w,)), reset=True, postselect=1) + + jaxpr = jax.make_jaxpr(f)(2) + + assert len(jaxpr.eqns) == 1 + assert jaxpr.eqns[0].primitive == MidMeasureMP._wires_primitive + assert jaxpr.eqns[0].params == {"reset": True, "postselect": 1} + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 1 + assert mp._abstract_eval == MidMeasureMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals, shots=qml.measurements.Shots(1)) + assert shapes[0] == jax.core.ShapedArray((), jax.numpy.int64 if x64_mode else jax.numpy.int32) + + mp = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)[0] + assert mp == f(1) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize( + "x64_mode, expected", [(True, jax.numpy.complex128), (False, jax.numpy.complex64)] +) +@pytest.mark.parametrize("state_wires, shape", [(None, 16), (qml.wires.Wires((0, 1, 2, 3, 4)), 32)]) +def test_state(x64_mode, expected, state_wires, shape): + """Test the capture of a state measurement.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return StateMP(wires=state_wires) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == StateMP._wires_primitive + assert len(jaxpr.eqns[0].invars) == 0 if state_wires is None else 5 + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 0 if state_wires is None else 5 + assert mp._abstract_eval == StateMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, shots=qml.measurements.Shots(None), num_device_wires=4 + )[0] + assert shapes == jax.core.ShapedArray((shape,), expected) + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize( + "x64_mode, expected", [(True, jax.numpy.complex128), (False, jax.numpy.complex64)] +) +@pytest.mark.parametrize("wires, shape", [([0, 1], (4, 4)), ([], (16, 16))]) +def test_density_matrix(wires, shape, x64_mode, expected): + """Test the capture of a density matrix.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return qml.density_matrix(wires=wires) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == DensityMatrixMP._wires_primitive + assert len(jaxpr.eqns[0].invars) == len(wires) + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == len(wires) + assert mp._abstract_eval == DensityMatrixMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, shots=qml.measurements.Shots(None), num_device_wires=4 + ) + assert shapes[0] == jax.core.ShapedArray(shape, expected) + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize( + "x64_mode, expected", [(True, jax.numpy.float64), (False, jax.numpy.float32)] +) +@pytest.mark.parametrize("m_type", (ExpectationMP, VarianceMP)) +class TestExpvalVar: + """Tests for capturing an expectation value or variance.""" + + def test_capture_obs(self, m_type, x64_mode, expected): + """Test that the expectation value of an observable can be captured.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return m_type(obs=qml.X(0)) + + jaxpr = jax.make_jaxpr(f)() + + assert len(jaxpr.eqns) == 2 + assert jaxpr.eqns[0].primitive == qml.X._primitive + + assert jaxpr.eqns[1].primitive == m_type._obs_primitive + assert jaxpr.eqns[0].outvars == jaxpr.eqns[1].invars + + am = jaxpr.eqns[1].outvars[0].aval + assert isinstance(am, AbstractMeasurement) + assert am.n_wires is None + assert am._abstract_eval == m_type._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, num_device_wires=0, shots=qml.measurements.Shots(50) + )[0] + assert shapes == jax.core.ShapedArray((), expected) + jax.config.update("jax_enable_x64", initial_mode) + + def test_capture_eigvals_wires(self, m_type, x64_mode, expected): + """Test that we can capture an expectation value of eigvals+wires.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(eigs): + return m_type(eigvals=eigs, wires=qml.wires.Wires((0, 1))) + + eigs = np.array([1.0, 0.5, -0.5, -1.0]) + jaxpr = jax.make_jaxpr(f)(eigs) + + assert len(jaxpr.eqns) == 1 + assert jaxpr.eqns[0].primitive == m_type._wires_primitive + assert jaxpr.eqns[0].params == {"has_eigvals": True} + assert [x.val for x in jaxpr.eqns[0].invars[:-1]] == [0, 1] # the wires + assert jaxpr.eqns[0].invars[-1] == jaxpr.jaxpr.invars[0] # the eigvals + + am = jaxpr.eqns[0].outvars[0].aval + assert isinstance(am, AbstractMeasurement) + assert am.n_wires == 2 + assert am._abstract_eval == m_type._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, num_device_wires=0, shots=qml.measurements.Shots(50) + )[0] + assert shapes == jax.core.ShapedArray((), expected) + jax.config.update("jax_enable_x64", initial_mode) + + def test_simple_single_mcm(self, m_type, x64_mode, expected): + """Test that we can take the expectation value of a mid circuit measurement.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + # using integer to represent classical mcm value + return m_type(obs=1) + + jaxpr = jax.make_jaxpr(f)() + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == m_type._mcm_primitive + aval1 = jaxpr.eqns[0].outvars[0].aval + assert isinstance(aval1, AbstractMeasurement) + assert aval1.n_wires == 1 + assert aval1._abstract_eval == m_type._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, num_device_wires=0, shots=qml.measurements.Shots(50) + )[0] + assert shapes == jax.core.ShapedArray((), expected) + + with pytest.raises(NotImplementedError): + f() + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +class TestProbs: + + @pytest.mark.parametrize("wires, shape", [([0, 1, 2], 8), ([], 16)]) + def test_wires(self, wires, shape, x64_mode): + """Tests capturing probabilities on wires.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return qml.probs(wires=wires) + + jaxpr = jax.make_jaxpr(f)() + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == ProbabilityMP._wires_primitive + assert [x.val for x in jaxpr.eqns[0].invars] == wires + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == len(wires) + assert mp._abstract_eval == ProbabilityMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=4 + )[0] + assert shapes == jax.core.ShapedArray( + (shape,), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_eigvals(self, x64_mode): + """Test capturing probabilities with eigenvalues.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(eigs): + return ProbabilityMP(eigvals=eigs, wires=qml.wires.Wires((0, 1))) + + eigvals = np.array([-1.0, -0.5, 0.5, 1.0]) + jaxpr = jax.make_jaxpr(f)(eigvals) + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == ProbabilityMP._wires_primitive + assert jaxpr.eqns[0].params == {"has_eigvals": True} + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 2 + assert mp._abstract_eval == ProbabilityMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals) + assert shapes[0] == jax.core.ShapedArray( + (4,), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_multiple_mcms(self, x64_mode): + """Test measuring probabilities of multiple mcms.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(c1, c2): + return qml.probs(op=[c1, c2]) + + jaxpr = jax.make_jaxpr(f)(1, 2) + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == ProbabilityMP._mcm_primitive + out = jaxpr.eqns[0].outvars[0].aval + assert isinstance(out, AbstractMeasurement) + assert out.n_wires == 2 + assert out._abstract_eval == ProbabilityMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=6 + ) + assert shapes[0] == jax.core.ShapedArray( + (4,), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + with pytest.raises(NotImplementedError): + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1, 2) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +class TestSample: + + @pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4)]) + def test_wires(self, wires, dim1_len, x64_mode): + """Tests capturing samples on wires.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(*inner_wires): + return qml.sample(wires=inner_wires) + + jaxpr = jax.make_jaxpr(f)(*wires) + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == SampleMP._wires_primitive + assert [x.aval for x in jaxpr.eqns[0].invars] == jaxpr.in_avals + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == len(wires) + assert mp._abstract_eval == SampleMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=4 + ) + assert shapes[0] == jax.core.ShapedArray( + (50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) + + with pytest.raises(ValueError, match="finite shots are required"): + jaxpr.out_avals[0].abstract_eval(shots=None, num_device_wires=4) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_eigvals(self, x64_mode): + """Test capturing samples with eigenvalues.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(eigs): + return SampleMP(eigvals=eigs, wires=qml.wires.Wires((0, 1))) + + eigvals = np.array([-1.0, -0.5, 0.5, 1.0]) + jaxpr = jax.make_jaxpr(f)(eigvals) + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == SampleMP._wires_primitive + assert jaxpr.eqns[0].params == {"has_eigvals": True} + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 2 + assert mp._abstract_eval == SampleMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals, shots=qml.measurements.Shots(50)) + assert shapes[0] == jax.core.ShapedArray( + (50,), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_multiple_mcms(self, x64_mode): + """Test sampling from multiple mcms.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return qml.sample(op=[1, 2]) + + jaxpr = jax.make_jaxpr(f)() + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == SampleMP._mcm_primitive + out = jaxpr.eqns[0].outvars[0].aval + assert isinstance(out, AbstractMeasurement) + assert out.n_wires == 2 + assert out._abstract_eval == SampleMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals, shots=qml.measurements.Shots(50)) + assert shapes[0] == jax.core.ShapedArray( + (50, 2), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) + + with pytest.raises(NotImplementedError): + f() + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_shadow_expval(x64_mode): + """Test that the shadow expval of an observable can be captured.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return qml.shadow_expval(qml.X(0), seed=887, k=4) + + jaxpr = jax.make_jaxpr(f)() + + assert len(jaxpr.eqns) == 2 + assert jaxpr.eqns[0].primitive == qml.X._primitive + + assert jaxpr.eqns[1].primitive == ShadowExpvalMP._obs_primitive + assert jaxpr.eqns[0].outvars == jaxpr.eqns[1].invars + assert jaxpr.eqns[1].params == {"seed": 887, "k": 4} + + am = jaxpr.eqns[1].outvars[0].aval + assert isinstance(am, AbstractMeasurement) + assert am.n_wires is None + assert am._abstract_eval == ShadowExpvalMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals, num_device_wires=0, shots=qml.measurements.Shots(50)) + assert shapes[0] == jax.core.ShapedArray( + (), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +@pytest.mark.parametrize("mtype, kwargs", [(VnEntropyMP, {"log_base": 2}), (PurityMP, {})]) +def test_qinfo_measurements(mtype, kwargs, x64_mode): + """Test the capture of a vn entropy and purity measurement.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(w1, w2): + return mtype(wires=qml.wires.Wires([w1, w2]), **kwargs) + + jaxpr = jax.make_jaxpr(f)(1, 2) + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == mtype._wires_primitive + assert jaxpr.eqns[0].params == kwargs + assert len(jaxpr.eqns[0].invars) == 2 + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 2 + assert mp._abstract_eval == mtype._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, num_device_wires=4, shots=qml.measurements.Shots(None) + ) + assert shapes[0] == jax.core.ShapedArray( + (), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_MutualInfo(x64_mode): + """Test the capture of a vn entropy and purity measurement.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(w1, w2): + return qml.mutual_info(wires0=[w1, 1], wires1=[w2, 3], log_base=2) + + jaxpr = jax.make_jaxpr(f)(0, 2) + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == MutualInfoMP._wires_primitive + assert jaxpr.eqns[0].params == {"log_base": 2, "n_wires0": 2} + assert len(jaxpr.eqns[0].invars) == 4 + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp._abstract_eval == MutualInfoMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, num_device_wires=4, shots=qml.measurements.Shots(None) + ) + assert shapes[0] == jax.core.ShapedArray( + (), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + +def test_ClassicalShadow(): + """Test that the classical shadow measurement can be captured.""" + + def f(): + return qml.classical_shadow(wires=(0, 1, 2), seed=95) + + jaxpr = jax.make_jaxpr(f)() + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == ClassicalShadowMP._wires_primitive + assert jaxpr.eqns[0].params == {"seed": 95} + assert len(jaxpr.eqns[0].invars) == 3 + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 3 + assert mp._abstract_eval == ClassicalShadowMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals, num_device_wires=4, shots=qml.measurements.Shots(50)) + assert shapes[0] == jax.core.ShapedArray((2, 50, 3), np.int8)