diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 0448d074a7f..0e4fc15314c 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -93,8 +93,8 @@ `m = measure(0); qml.sample(m)`. [(#5673)](https://github.com/PennyLaneAI/pennylane/pull/5673) -* PennyLane operators can now automatically be captured as instructions in JAXPR. See the experimental - `capture` module for more information. +* PennyLane operators and measurements can now automatically be captured as instructions in JAXPR. + [(#5564)](https://github.com/PennyLaneAI/pennylane/pull/5564) [(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511) * The `decompose` transform has an `error` kwarg to specify the type of error that should be raised, diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index a3a35dd427c..6cd2dacf49d 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -21,6 +21,18 @@ This module is experimental and will change significantly in the future. +.. currentmodule:: pennylane.capture + +.. autosummary:: + :toctree: api + + ~disable + ~enable + ~enabled + ~create_operator_primitive + ~create_measurement_obs_primitive + ~create_measurement_wires_primitive + ~create_measurement_mcm_primitive To activate and deactivate the new PennyLane program capturing mechanism, use the switches ``qml.capture.enable`` and ``qml.capture.disable``. @@ -114,4 +126,17 @@ def _(*args, **kwargs): """ from .switches import disable, enable, enabled from .capture_meta import CaptureMeta -from .primitives import create_operator_primitive +from .primitives import ( + create_operator_primitive, + create_measurement_obs_primitive, + create_measurement_wires_primitive, + create_measurement_mcm_primitive, +) + + +def __getattr__(key): + if key == "AbstractOperator": + from .primitives import _get_abstract_operator # pylint: disable=import-outside-toplevel + + return _get_abstract_operator() + raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'") diff --git a/pennylane/capture/capture_meta.py b/pennylane/capture/capture_meta.py index dcd4d78efa1..fe82b4926f6 100644 --- a/pennylane/capture/capture_meta.py +++ b/pennylane/capture/capture_meta.py @@ -28,6 +28,39 @@ class CaptureMeta(type): See ``pennylane/capture/explanations.md`` for more detailed information on how this technically works. + + .. code-block:: + + class AbstractMyObj(jax.core.AbstractValue): + pass + + jax.core.raise_to_shaped_mappings[AbstractMyObj] = lambda aval, _: aval + + class MyObj(metaclass=qml.capture.CaptureMeta): + + primitive = jax.core.Primitive("MyObj") + + @classmethod + def _primitive_bind_call(cls, a): + return cls.primitive.bind(a) + + def __init__(self, a): + self.a = a + + @MyObj.primitive.def_impl + def _(a): + return type.__call__(MyObj, a) + + @MyObj.primitive.def_abstract_eval + def _(a): + return AbstractMyObj() + + >>> jaxpr = jax.make_jaxpr(MyObj)(0.1) + >>> jaxpr + { lambda ; a:f32[]. let b:AbstractMyObj() = MyObj a in (b,) } + >>> jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.1) + [<__main__.MyObj at 0x17fc3ea50>] + """ @property diff --git a/pennylane/capture/primitives.py b/pennylane/capture/primitives.py index 6c568288f57..c8affc0a445 100644 --- a/pennylane/capture/primitives.py +++ b/pennylane/capture/primitives.py @@ -16,7 +16,7 @@ """ from functools import lru_cache -from typing import Optional +from typing import Callable, Optional, Tuple, Type import pennylane as qml @@ -82,9 +82,94 @@ def _pow(a, b): return AbstractOperator -def create_operator_primitive(operator_type: type) -> Optional["jax.core.Primitive"]: +@lru_cache +def _get_abstract_measurement(): + if not has_jax: # pragma: no cover + raise ImportError("Jax is required for plxpr.") # pragma: no cover + + class AbstractMeasurement(jax.core.AbstractValue): + """An abstract measurement. + + Args: + abstract_eval (Callable): See :meth:`~.MeasurementProcess._abstract_eval`. A function of + ``n_wires``, ``has_eigvals``, ``num_device_wires`` and ``shots`` that returns a shape + and numeric type. + n_wires=None (Optional[int]): the number of wires + has_eigvals=False (bool): Whether or not the measurement contains eigenvalues in a wires+eigvals + diagonal representation. + + """ + + def __init__( + self, abstract_eval: Callable, n_wires: Optional[int] = None, has_eigvals: bool = False + ): + self._abstract_eval = abstract_eval + self._n_wires = n_wires + self.has_eigvals: bool = has_eigvals + + def abstract_eval(self, num_device_wires: int, shots: int) -> Tuple[Tuple, type]: + """Calculate the shape and dtype for an evaluation with specified number of device + wires and shots. + + """ + return self._abstract_eval( + n_wires=self._n_wires, + has_eigvals=self.has_eigvals, + num_device_wires=num_device_wires, + shots=shots, + ) + + @property + def n_wires(self) -> Optional[int]: + """The number of wires for a wire based measurement. + + Options are: + * ``None``: The measurement is observable based or single mcm based + * ``0``: The measurement is broadcasted across all available devices wires + * ``int>0``: A wire or mcm based measurement with specified wires or mid circuit measurements. + + """ + return self._n_wires + + def __repr__(self): + if self.has_eigvals: + return f"AbstractMeasurement(n_wires={self.n_wires}, has_eigvals=True)" + return f"AbstractMeasurement(n_wires={self.n_wires})" + + # pylint: disable=missing-function-docstring + def at_least_vspace(self): + # TODO: investigate the proper definition of this method + raise NotImplementedError + + # pylint: disable=missing-function-docstring + def join(self, other): + # TODO: investigate the proper definition of this method + raise NotImplementedError + + # pylint: disable=missing-function-docstring + def update(self, **kwargs): + # TODO: investigate the proper definition of this method + raise NotImplementedError + + def __eq__(self, other): + return isinstance(other, AbstractMeasurement) + + def __hash__(self): + return hash("AbstractMeasurement") + + jax.core.raise_to_shaped_mappings[AbstractMeasurement] = lambda aval, _: aval + + return AbstractMeasurement + + +def create_operator_primitive( + operator_type: Type["qml.operation.Operator"], +) -> Optional["jax.core.Primitive"]: """Create a primitive corresponding to an operator type. + Called when defining any :class:`~.Operator` subclass, and is used to set the + ``Operator._primitive`` class property. + Args: operator_type (type): a subclass of qml.operation.Operator @@ -117,3 +202,122 @@ def _(*_, **__): return abstract_type() return primitive + + +def create_measurement_obs_primitive( + measurement_type: Type["qml.measurements.MeasurementProcess"], name: str +) -> Optional["jax.core.Primitive"]: + """Create a primitive corresponding to the input type where the abstract inputs are an operator. + + Called by default when defining any class inheriting from :class:`~.MeasurementProcess`, and is used to + set the ``MeasurementProcesss._obs_primitive`` property. + + Args: + measurement_type (type): a subclass of :class:`~.MeasurementProcess` + name (str): the preferred string name for the class. For example, ``"expval"``. + ``"_obs"`` is appended to this name for the name of the primitive. + + Returns: + Optional[jax.core.Primitive]: A new jax primitive. ``None`` is returned if jax is not available. + + """ + if not has_jax: + return None + + primitive = jax.core.Primitive(name + "_obs") + + @primitive.def_impl + def _(obs, **kwargs): + return type.__call__(measurement_type, obs=obs, **kwargs) + + abstract_type = _get_abstract_measurement() + + @primitive.def_abstract_eval + def _(*_, **__): + abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access + return abstract_type(abstract_eval, n_wires=None) + + return primitive + + +def create_measurement_mcm_primitive( + measurement_type: Type["qml.measurements.MeasurementProcess"], name: str +) -> Optional["jax.core.Primitive"]: + """Create a primitive corresponding to the input type where the abstract inputs are classical + mid circuit measurement results. + + Called by default when defining any class inheriting from :class:`~.MeasurementProcess`, and is used to + set the ``MeasurementProcesss._mcm_primitive`` property. + + Args: + measurement_type (type): a subclass of :class:`~.MeasurementProcess` + name (str): the preferred string name for the class. For example, ``"expval"``. + ``"_mcm"`` is appended to this name for the name of the primitive. + + Returns: + Optional[jax.core.Primitive]: A new jax primitive. ``None`` is returned if jax is not available. + """ + + if not has_jax: + return None + + primitive = jax.core.Primitive(name + "_mcm") + + @primitive.def_impl + def _(*mcms, **kwargs): + raise NotImplementedError( + "mcm measurements do not currently have a concrete implementation" + ) + # need to figure out how to convert a jaxpr int into a measurement value, and pass + # that measurment value here. + # return type.__call__(measurement_type, obs=mcms, **kwargs) + + abstract_type = _get_abstract_measurement() + + @primitive.def_abstract_eval + def _(*mcms, **__): + abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access + return abstract_type(abstract_eval, n_wires=len(mcms)) + + return primitive + + +def create_measurement_wires_primitive( + measurement_type: type, name: str +) -> Optional["jax.core.Primitive"]: + """Create a primitive corresponding to the input type where the abstract inputs are the wires. + + Called by default when defining any class inheriting from :class:`~.MeasurementProcess`, and is used to + set the ``MeasurementProcesss._wires_primitive`` property. + + Args: + measurement_type (type): a subclass of :class:`~.MeasurementProcess` + name (str): the preferred string name for the class. For example, ``"expval"``. + ``"_wires"`` is appended to this name for the name of the primitive. + + Returns: + Optional[jax.core.Primitive]: A new jax primitive. ``None`` is returned if jax is not available. + """ + if not has_jax: + return None + + primitive = jax.core.Primitive(name + "_wires") + + @primitive.def_impl + def _(*args, has_eigvals=False, **kwargs): + if has_eigvals: + wires = qml.wires.Wires(args[:-1]) + kwargs["eigvals"] = args[-1] + else: + wires = qml.wires.Wires(args) + return type.__call__(measurement_type, wires=wires, **kwargs) + + abstract_type = _get_abstract_measurement() + + @primitive.def_abstract_eval + def _(*args, has_eigvals=False, **_): + abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access + n_wires = len(args) - 1 if has_eigvals else len(args) + return abstract_type(abstract_eval, n_wires=n_wires, has_eigvals=has_eigvals) + + return primitive diff --git a/pennylane/measurements/classical_shadow.py b/pennylane/measurements/classical_shadow.py index 5335c6b3379..8f9a841a66f 100644 --- a/pennylane/measurements/classical_shadow.py +++ b/pennylane/measurements/classical_shadow.py @@ -440,6 +440,16 @@ def numeric_type(self): def return_type(self): return Shadow + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ) -> tuple: + return (2, shots, n_wires), np.int8 + def shape(self, device, shots): # pylint: disable=unused-argument # otherwise, the return type requires a device if not shots: @@ -496,6 +506,19 @@ def __init__( self.k = k super().__init__(id=id) + # pylint: disable=arguments-differ + @classmethod + def _primitive_bind_call( + cls, + H: Union[Operator, Sequence], + seed: Optional[int] = None, + k: int = 1, + **kwargs, + ): + if cls._obs_primitive is None: # pragma: no cover + return type.__call__(cls, H=H, seed=seed, k=k, **kwargs) # pragma: no cover + return cls._obs_primitive.bind(H, seed=seed, k=k, **kwargs) + def process(self, tape, device): bits, recipes = qml.classical_shadow(wires=self.wires, seed=self.seed).process(tape, device) shadow = qml.shadows.ClassicalShadow(bits, recipes, wire_map=self.wires.tolist()) @@ -573,3 +596,10 @@ def __copy__(self): k=self.k, seed=self.seed, ) + + +if ShadowExpvalMP._obs_primitive is not None: # pylint: disable=protected-access + + @ShadowExpvalMP._obs_primitive.def_impl # pylint: disable=protected-access + def _(H, **kwargs): + return type.__call__(ShadowExpvalMP, H, **kwargs) diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py index 903052edae1..807be11d2e8 100644 --- a/pennylane/measurements/counts.py +++ b/pennylane/measurements/counts.py @@ -205,6 +205,18 @@ def __repr__(self): return f"CountsMP(wires={self.wires.tolist()}, all_outcomes={self.all_outcomes})" + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ) -> tuple: + raise NotImplementedError( + "CountsMP returns a dictionary, which is not compatible with capture." + ) + @property def hash(self): """int: returns an integer hash uniquely representing the measurement process""" diff --git a/pennylane/measurements/expval.py b/pennylane/measurements/expval.py index 330be5cf434..f87f7b755dc 100644 --- a/pennylane/measurements/expval.py +++ b/pennylane/measurements/expval.py @@ -91,9 +91,7 @@ class ExpectationMP(SampleMeasurement, StateMeasurement): where the instance has to be identified """ - @property - def return_type(self): - return Expectation + return_type = Expectation @property def numeric_type(self): diff --git a/pennylane/measurements/measurements.py b/pennylane/measurements/measurements.py index 8e8ac4f34eb..3397d95f03c 100644 --- a/pennylane/measurements/measurements.py +++ b/pennylane/measurements/measurements.py @@ -18,7 +18,7 @@ """ import copy import functools -from abc import ABC, abstractmethod +from abc import ABC, ABCMeta, abstractmethod from enum import Enum from typing import Optional, Sequence, Tuple, Union @@ -112,7 +112,12 @@ class MeasurementShapeError(ValueError): quantum tape.""" -class MeasurementProcess(ABC): +# pylint: disable=abstract-method +class ABCCaptureMeta(qml.capture.CaptureMeta, ABCMeta): + """A combination of the capture meta and ABCMeta""" + + +class MeasurementProcess(ABC, metaclass=ABCCaptureMeta): """Represents a measurement process occurring at the end of a quantum variational circuit. @@ -130,8 +135,84 @@ class MeasurementProcess(ABC): # pylint:disable=too-many-instance-attributes + _obs_primitive: Optional["jax.core.Primitive"] = None + _wires_primitive: Optional["jax.core.Primitive"] = None + _mcm_primitive: Optional["jax.core.Primitive"] = None + def __init_subclass__(cls, **_): register_pytree(cls, cls._flatten, cls._unflatten) + name = getattr(cls.return_type, "value", cls.__name__) + cls._wires_primitive = qml.capture.create_measurement_wires_primitive(cls, name=name) + cls._obs_primitive = qml.capture.create_measurement_obs_primitive(cls, name=name) + cls._mcm_primitive = qml.capture.create_measurement_mcm_primitive(cls, name=name) + + @classmethod + def _primitive_bind_call(cls, obs=None, wires=None, eigvals=None, id=None, **kwargs): + """Called instead of ``type.__call__`` if ``qml.capture.enabled()``. + + Measurements have three "modes": + + 1) Wires or wires + eigvals + 2) Observable + 3) Mid circuit measurements + + Not all measurements support all three modes. For example, ``VNEntropyMP`` does not + allow being specified via an observable. But we handle the generic case here. + + """ + if cls._obs_primitive is None: + # safety check if primitives aren't set correctly. + return type.__call__(cls, obs=obs, wires=wires, eigvals=eigvals, id=id, **kwargs) + if obs is None: + wires = () if wires is None else wires + if eigvals is None: + return cls._wires_primitive.bind(*wires, **kwargs) # wires + return cls._wires_primitive.bind( + *wires, eigvals, has_eigvals=True, **kwargs + ) # wires + eigvals + + if isinstance(obs, Operator) or isinstance( + getattr(obs, "aval", None), qml.capture.AbstractOperator + ): + return cls._obs_primitive.bind(obs, **kwargs) + if isinstance(obs, (list, tuple)): + return cls._mcm_primitive.bind(*obs, **kwargs) # iterable of mcms + return cls._mcm_primitive.bind(obs, **kwargs) # single mcm + + # pylint: disable=unused-argument + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ) -> tuple[tuple, type]: + """Calculate the shape and dtype that will be returned when a measurement is performed. + + This information is similar to ``numeric_type`` and ``shape``, but is provided through + a class method and does not require the creation of an instance. + + Note that ``shots`` should strictly be ``None`` or ``int``. Shot vectors are handled higher + in the stack. + + If ``n_wires is None``, then the measurement process contains an observable. An integer + ``n_wires`` can correspond either to the number of wires or to the number of mid circuit + measurements. ``n_wires = 0`` indicates a measurement that is broadcasted across all device wires. + + >>> ProbabilityMP._abstract_eval(n_wires=2) + ((4,), float) + >>> ProbabilityMP._abstract_eval(n_wires=0, num_device_wires=2) + ((4,), float) + >>> SampleMP._abstract_eval(n_wires=0, shots=50, num_device_wires=2) + ((50, 2), int) + >>> SampleMP._abstract_eval(n_wires=4, has_eigvals=True, shots=50) + ((50,), float) + >>> SampleMP._abstract_eval(n_wires=None, shots=50) + ((50,), float) + + """ + return (), float def _flatten(self): metadata = (("wires", self.raw_wires),) @@ -173,7 +254,7 @@ def __init__( self.id = id if wires is not None: - if len(wires) == 0: + if not qml.capture.enabled() and len(wires) == 0: raise ValueError("Cannot set an empty list of wires.") if obs is not None: raise ValueError("Cannot set the wires if an observable is provided.") @@ -282,12 +363,13 @@ def __hash__(self): def __repr__(self): """Representation of this class.""" + name_str = self.return_type.value if self.return_type else type(self).__name__ if self.mv: - return f"{self.return_type.value}({repr(self.mv)})" + return f"{name_str}({repr(self.mv)})" if self.obs: - return f"{self.return_type.value}({self.obs})" + return f"{name_str}({self.obs})" if self._eigvals is not None: - return f"{self.return_type.value}(eigvals={self._eigvals}, wires={self.wires.tolist()})" + return f"{name_str}(eigvals={self._eigvals}, wires={self.wires.tolist()})" # Todo: when tape is core the return type will always be taken from the MeasurementProcess return f"{getattr(self.return_type, 'value', 'None')}(wires={self.wires.tolist()})" diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index ea5a19196b3..1c927813395 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -15,7 +15,7 @@ This module contains the qml.measure measurement. """ import uuid -from typing import Generic, Optional, TypeVar +from typing import Generic, Hashable, Optional, TypeVar, Union import pennylane as qml from pennylane.wires import Wires @@ -23,7 +23,9 @@ from .measurements import MeasurementProcess, MidMeasure -def measure(wires: Wires, reset: Optional[bool] = False, postselect: Optional[int] = None): +def measure( + wires: Union[Hashable, Wires], reset: Optional[bool] = False, postselect: Optional[int] = None +): r"""Perform a mid-circuit measurement in the computational basis on the supplied qubit. @@ -207,7 +209,6 @@ def func(x): samples, leading to unexpected or incorrect results. """ - wire = Wires(wires) if len(wire) > 1: raise qml.QuantumFunctionError( @@ -217,6 +218,10 @@ def func(x): # Create a UUID and a map between MP and MV to support serialization measurement_id = str(uuid.uuid4())[:8] mp = MidMeasureMP(wires=wire, reset=reset, postselect=postselect, id=measurement_id) + if qml.capture.enabled(): + raise NotImplementedError( + "Capture cannot currently handle classical output from mid circuit measurements." + ) return MeasurementValue([mp], processing_fn=lambda v: v) @@ -257,6 +262,22 @@ def __init__( self.reset = reset self.postselect = postselect + # pylint: disable=arguments-renamed, arguments-differ + @classmethod + def _primitive_bind_call(cls, wires=None, reset=False, postselect=None, id=None): + wires = () if wires is None else wires + return cls._wires_primitive.bind(*wires, reset=reset, postselect=postselect) + + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ) -> tuple: + return (), int + def label(self, decimals=None, base_label=None, cache=None): # pylint: disable=unused-argument r"""How the mid-circuit measurement is represented in diagrams and drawings. diff --git a/pennylane/measurements/mutual_info.py b/pennylane/measurements/mutual_info.py index 88f7ed3183d..d52f282ab55 100644 --- a/pennylane/measurements/mutual_info.py +++ b/pennylane/measurements/mutual_info.py @@ -79,7 +79,9 @@ def circuit_mutual(x): wires1 = qml.wires.Wires(wires1) # the subsystems cannot overlap - if [wire for wire in wires0 if wire in wires1]: + if not any(qml.math.is_abstract(w) for w in wires0 + wires1) and [ + wire for wire in wires0 if wire in wires1 + ]: raise qml.QuantumFunctionError( "Subsystems for computing mutual information must not overlap." ) @@ -113,6 +115,14 @@ def __init__( self.log_base = log_base super().__init__(wires=wires, id=id) + # pylint: disable=arguments-differ + @classmethod + def _primitive_bind_call(cls, wires: Sequence, **kwargs): + if cls._wires_primitive is None: # pragma: no cover + # just a safety check + return type.__call__(cls, wires=wires, **kwargs) # pragma: no cover + return cls._wires_primitive.bind(*wires[0], *wires[1], n_wires0=len(wires[0]), **kwargs) + def __repr__(self): return f"MutualInfo(wires0={self.raw_wires[0].tolist()}, wires1={self.raw_wires[1].tolist()}, log_base={self.log_base})" @@ -158,3 +168,12 @@ def process_state(self, state: Sequence[complex], wire_order: Wires): c_dtype=state.dtype, base=self.log_base, ) + + +if MutualInfoMP._wires_primitive is not None: + + @MutualInfoMP._wires_primitive.def_impl + def _(*all_wires, n_wires0, **kwargs): + wires0 = all_wires[:n_wires0] + wires1 = all_wires[n_wires0:] + return type.__call__(MutualInfoMP, wires=(wires0, wires1), **kwargs) diff --git a/pennylane/measurements/probs.py b/pennylane/measurements/probs.py index a0ceb468f80..c1c0dc45c6e 100644 --- a/pennylane/measurements/probs.py +++ b/pennylane/measurements/probs.py @@ -44,7 +44,7 @@ def probs(wires=None, op=None) -> "ProbabilityMP": Args: wires (Sequence[int] or int): the wire the operation acts on - op (Observable or MeasurementValue): Observable (with a ``diagonalizing_gates`` + op (Observable or MeasurementValue or Sequence[MeasurementValue]): Observable (with a ``diagonalizing_gates`` attribute) that rotates the computational basis, or a ``MeasurementValue`` corresponding to mid-circuit measurements. @@ -103,7 +103,9 @@ def circuit(): return ProbabilityMP(obs=op) if isinstance(op, Sequence): - if not all(isinstance(o, MeasurementValue) and len(o.measurements) == 1 for o in op): + if not qml.math.is_abstract(op[0]) and not all( + isinstance(o, MeasurementValue) and len(o.measurements) == 1 for o in op + ): raise qml.QuantumFunctionError( "Only sequences of single MeasurementValues can be passed with the op argument. " "MeasurementValues manipulated using arithmetic operators cannot be used when " @@ -115,7 +117,7 @@ def circuit(): if isinstance(op, (qml.ops.Hamiltonian, qml.ops.LinearCombination)): raise qml.QuantumFunctionError("Hamiltonians are not supported for rotating probabilities.") - if op is not None and not op.has_diagonalizing_gates: + if op is not None and not qml.math.is_abstract(op) and not op.has_diagonalizing_gates: raise qml.QuantumFunctionError( f"{op} does not define diagonalizing gates : cannot be used to rotate the probability" ) @@ -147,9 +149,13 @@ class ProbabilityMP(SampleMeasurement, StateMeasurement): where the instance has to be identified """ - @property - def return_type(self): - return Probability + return_type = Probability + + @classmethod + def _abstract_eval(cls, n_wires=None, has_eigvals=False, shots=None, num_device_wires=0): + n_wires = num_device_wires if n_wires == 0 else n_wires + shape = (2**n_wires,) + return shape, float @property def numeric_type(self): diff --git a/pennylane/measurements/sample.py b/pennylane/measurements/sample.py index d36a78e4ec8..da10c55e61c 100644 --- a/pennylane/measurements/sample.py +++ b/pennylane/measurements/sample.py @@ -28,7 +28,7 @@ def sample( - op: Optional[Union[Operator, MeasurementValue]] = None, + op: Optional[Union[Operator, MeasurementValue, Sequence[MeasurementValue]]] = None, wires=None, ) -> "SampleMP": r"""Sample from the supplied observable, with the number of shots @@ -182,9 +182,34 @@ def __init__(self, obs=None, wires=None, eigvals=None, id=None): super().__init__(obs=obs, wires=wires, eigvals=eigvals, id=id) - @property - def return_type(self): - return Sample + return_type = Sample + + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ): + if shots is None: + raise ValueError("finite shots are required to use SampleMP") + sample_eigvals = n_wires is None or has_eigvals + dtype = float if sample_eigvals else int + + if n_wires == 0: + dim = num_device_wires + elif sample_eigvals: + dim = 1 + else: + dim = n_wires + + shape = [] + if shots != 1: + shape.append(shots) + if dim != 1: + shape.append(dim) + return tuple(shape), dtype @property @functools.lru_cache() diff --git a/pennylane/measurements/state.py b/pennylane/measurements/state.py index fdf37402e26..d97fcf8a971 100644 --- a/pennylane/measurements/state.py +++ b/pennylane/measurements/state.py @@ -139,9 +139,19 @@ class StateMP(StateMeasurement): def __init__(self, wires: Optional[Wires] = None, id: Optional[str] = None): super().__init__(wires=wires, id=id) - @property - def return_type(self): - return State + return_type = State + + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ): + n_wires = n_wires or num_device_wires + shape = (2**n_wires,) + return shape, complex @property def numeric_type(self): @@ -196,6 +206,18 @@ class DensityMatrixMP(StateMP): def __init__(self, wires: Wires, id: Optional[str] = None): super().__init__(wires=wires, id=id) + @classmethod + def _abstract_eval( + cls, + n_wires: Optional[int] = None, + has_eigvals=False, + shots: Optional[int] = None, + num_device_wires: int = 0, + ): + n_wires = n_wires or num_device_wires + shape = (2**n_wires, 2**n_wires) + return shape, complex + def shape(self, device, shots): num_shot_elements = ( sum(s.copies for s in shots.shot_vector) if shots.has_partitioned_shots else 1 diff --git a/pennylane/measurements/var.py b/pennylane/measurements/var.py index 32235ea56d7..613588ea572 100644 --- a/pennylane/measurements/var.py +++ b/pennylane/measurements/var.py @@ -83,9 +83,7 @@ class VarianceMP(SampleMeasurement, StateMeasurement): where the instance has to be identified """ - @property - def return_type(self): - return Variance + return_type = Variance @property def numeric_type(self): diff --git a/tests/capture/test_capture_module.py b/tests/capture/test_capture_module.py new file mode 100644 index 00000000000..c403fa04baa --- /dev/null +++ b/tests/capture/test_capture_module.py @@ -0,0 +1,26 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests capture module imports and access. +""" +import pytest + +import pennylane as qml + + +def test_no_attribute_available(): + """Test that if we try and access an attribute that doesn't exist, we get an attribute error.""" + + with pytest.raises(AttributeError): + _ = qml.capture.something diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py new file mode 100644 index 00000000000..8f877bca3ab --- /dev/null +++ b/tests/capture/test_measurements_capture.py @@ -0,0 +1,661 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for capturing measurements. +""" +import numpy as np + +# pylint: disable=protected-access +import pytest + +import pennylane as qml +from pennylane.capture.primitives import _get_abstract_measurement +from pennylane.measurements import ( + ClassicalShadowMP, + DensityMatrixMP, + ExpectationMP, + MidMeasureMP, + MutualInfoMP, + ProbabilityMP, + PurityMP, + SampleMP, + ShadowExpvalMP, + StateMP, + VarianceMP, + VnEntropyMP, +) + +jax = pytest.importorskip("jax") + +pytestmark = pytest.mark.jax + +AbstractMeasurement = _get_abstract_measurement() + + +@pytest.fixture(autouse=True) +def enable_disable_plxpr(): + qml.capture.enable() + yield + qml.capture.disable() + + +def _get_shapes_for(*measurements, shots=qml.measurements.Shots(None), num_device_wires=0): + if jax.config.jax_enable_x64: + dtype_map = { + float: jax.numpy.float64, + int: jax.numpy.int64, + complex: jax.numpy.complex128, + } + else: + dtype_map = { + float: jax.numpy.float32, + int: jax.numpy.int32, + complex: jax.numpy.complex64, + } + + shapes = [] + if not shots: + shots = [None] + + for s in shots: + for m in measurements: + shape, dtype = m.abstract_eval(shots=s, num_device_wires=num_device_wires) + shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype))) + return shapes + + +def test_abstract_measurement(): + """Tests for the AbstractMeasurement class.""" + am = AbstractMeasurement(ExpectationMP._abstract_eval, n_wires=2, has_eigvals=True) + + assert am.n_wires == 2 + assert am.has_eigvals is True + + expected_repr = "AbstractMeasurement(n_wires=2, has_eigvals=True)" + assert repr(am) == expected_repr + + assert am.abstract_eval(2, 50) == ((), float) + + with pytest.raises(NotImplementedError): + am.at_least_vspace() + + with pytest.raises(NotImplementedError): + am.join(am) + + with pytest.raises(NotImplementedError): + am.update(key="value") + + am2 = AbstractMeasurement(ExpectationMP._abstract_eval) + expected_repr2 = "AbstractMeasurement(n_wires=None)" + assert repr(am2) == expected_repr2 + + assert am == am2 + assert hash(am) == hash("AbstractMeasurement") + + +def test_counts_no_measure(): + """Test that counts can't be measured and raises a NotImplementedError.""" + + with pytest.raises(NotImplementedError, match=r"CountsMP returns a dictionary"): + qml.counts()._abstract_eval() + + +def test_mid_measure_not_implemented(): + """Test that measure raises a NotImplementedError if capture is enabled.""" + with pytest.raises(NotImplementedError): + qml.measure(0) + + +def test_primitive_none_behavior(): + """Test that if the obs primitive is None, the measurement can still + be created, but it just won't be captured into jaxpr. + """ + + # pylint: disable=too-few-public-methods + class MyMeasurement(qml.measurements.MeasurementProcess): + pass + + MyMeasurement._obs_primitive = None + + def f(): + return MyMeasurement(wires=qml.wires.Wires((0, 1))) + + mp = f() + assert isinstance(mp, MyMeasurement) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 0 + + +# pylint: disable=unnecessary-lambda +creation_funcs = [ + lambda: qml.state(), + lambda: qml.density_matrix(wires=(0, 1)), + lambda: qml.expval(qml.X(0)), + lambda: ExpectationMP(wires=qml.wires.Wires((0, 1)), eigvals=np.array([-1.0, -0.5, 0.5, 1.0])), + # lambda : qml.expval(qml.measure(0)+qml.measure(1)), + lambda: qml.var(qml.X(0)), + lambda: VarianceMP(wires=qml.wires.Wires((0, 1)), eigvals=np.array([-1.0, -0.5, 0.5, 1.0])), + # lambda : qml.var(qml.measure(0)+qml.measure(1)), + lambda: qml.probs(wires=(0, 1)), + lambda: qml.probs(op=qml.X(0)), + # lambda : qml.probs(op=[qml.measure(0), qml.measure(1)]), + lambda: ProbabilityMP(wires=qml.wires.Wires((0, 1)), eigvals=np.array([-1.0, -0.5, 0.5, 1.0])), + lambda: qml.sample(wires=(3, 4)), + lambda: qml.shadow_expval(np.array(2) * qml.X(0)), + lambda: qml.vn_entropy(wires=(1, 2)), + lambda: qml.purity(wires=(0, 1)), + lambda: qml.mutual_info(wires0=(1, 3), wires1=(2, 4), log_base=2), + lambda: qml.classical_shadow(wires=(0, 1), seed=84), + lambda: MidMeasureMP(qml.wires.Wires((0, 1))), +] + + +@pytest.mark.parametrize("func", creation_funcs) +def test_capture_and_eval(func): + """Test that captured jaxpr can be evaluated to restore the initial measurement.""" + + mp = func() + + jaxpr = jax.make_jaxpr(func)() + out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)[0] + + assert qml.equal(mp, out) + + +@pytest.mark.parametrize("x64_mode", [True, False]) +def test_mid_measure(x64_mode): + """Test that mid circuit measurements can be captured and executed.x""" + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(w): + return MidMeasureMP(qml.wires.Wires((w,)), reset=True, postselect=1) + + jaxpr = jax.make_jaxpr(f)(2) + + assert len(jaxpr.eqns) == 1 + assert jaxpr.eqns[0].primitive == MidMeasureMP._wires_primitive + assert jaxpr.eqns[0].params == {"reset": True, "postselect": 1} + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 1 + assert mp._abstract_eval == MidMeasureMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals, shots=qml.measurements.Shots(1)) + assert shapes[0] == jax.core.ShapedArray((), jax.numpy.int64 if x64_mode else jax.numpy.int32) + + mp = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)[0] + assert mp == f(1) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize( + "x64_mode, expected", [(True, jax.numpy.complex128), (False, jax.numpy.complex64)] +) +@pytest.mark.parametrize("state_wires, shape", [(None, 16), (qml.wires.Wires((0, 1, 2, 3, 4)), 32)]) +def test_state(x64_mode, expected, state_wires, shape): + """Test the capture of a state measurement.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return StateMP(wires=state_wires) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == StateMP._wires_primitive + assert len(jaxpr.eqns[0].invars) == 0 if state_wires is None else 5 + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 0 if state_wires is None else 5 + assert mp._abstract_eval == StateMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, shots=qml.measurements.Shots(None), num_device_wires=4 + )[0] + assert shapes == jax.core.ShapedArray((shape,), expected) + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize( + "x64_mode, expected", [(True, jax.numpy.complex128), (False, jax.numpy.complex64)] +) +@pytest.mark.parametrize("wires, shape", [([0, 1], (4, 4)), ([], (16, 16))]) +def test_density_matrix(wires, shape, x64_mode, expected): + """Test the capture of a density matrix.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return qml.density_matrix(wires=wires) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == DensityMatrixMP._wires_primitive + assert len(jaxpr.eqns[0].invars) == len(wires) + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == len(wires) + assert mp._abstract_eval == DensityMatrixMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, shots=qml.measurements.Shots(None), num_device_wires=4 + ) + assert shapes[0] == jax.core.ShapedArray(shape, expected) + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize( + "x64_mode, expected", [(True, jax.numpy.float64), (False, jax.numpy.float32)] +) +@pytest.mark.parametrize("m_type", (ExpectationMP, VarianceMP)) +class TestExpvalVar: + """Tests for capturing an expectation value or variance.""" + + def test_capture_obs(self, m_type, x64_mode, expected): + """Test that the expectation value of an observable can be captured.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return m_type(obs=qml.X(0)) + + jaxpr = jax.make_jaxpr(f)() + + assert len(jaxpr.eqns) == 2 + assert jaxpr.eqns[0].primitive == qml.X._primitive + + assert jaxpr.eqns[1].primitive == m_type._obs_primitive + assert jaxpr.eqns[0].outvars == jaxpr.eqns[1].invars + + am = jaxpr.eqns[1].outvars[0].aval + assert isinstance(am, AbstractMeasurement) + assert am.n_wires is None + assert am._abstract_eval == m_type._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, num_device_wires=0, shots=qml.measurements.Shots(50) + )[0] + assert shapes == jax.core.ShapedArray((), expected) + jax.config.update("jax_enable_x64", initial_mode) + + def test_capture_eigvals_wires(self, m_type, x64_mode, expected): + """Test that we can capture an expectation value of eigvals+wires.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(eigs): + return m_type(eigvals=eigs, wires=qml.wires.Wires((0, 1))) + + eigs = np.array([1.0, 0.5, -0.5, -1.0]) + jaxpr = jax.make_jaxpr(f)(eigs) + + assert len(jaxpr.eqns) == 1 + assert jaxpr.eqns[0].primitive == m_type._wires_primitive + assert jaxpr.eqns[0].params == {"has_eigvals": True} + assert [x.val for x in jaxpr.eqns[0].invars[:-1]] == [0, 1] # the wires + assert jaxpr.eqns[0].invars[-1] == jaxpr.jaxpr.invars[0] # the eigvals + + am = jaxpr.eqns[0].outvars[0].aval + assert isinstance(am, AbstractMeasurement) + assert am.n_wires == 2 + assert am._abstract_eval == m_type._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, num_device_wires=0, shots=qml.measurements.Shots(50) + )[0] + assert shapes == jax.core.ShapedArray((), expected) + jax.config.update("jax_enable_x64", initial_mode) + + def test_simple_single_mcm(self, m_type, x64_mode, expected): + """Test that we can take the expectation value of a mid circuit measurement.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + # using integer to represent classical mcm value + return m_type(obs=1) + + jaxpr = jax.make_jaxpr(f)() + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == m_type._mcm_primitive + aval1 = jaxpr.eqns[0].outvars[0].aval + assert isinstance(aval1, AbstractMeasurement) + assert aval1.n_wires == 1 + assert aval1._abstract_eval == m_type._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, num_device_wires=0, shots=qml.measurements.Shots(50) + )[0] + assert shapes == jax.core.ShapedArray((), expected) + + with pytest.raises(NotImplementedError): + f() + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +class TestProbs: + + @pytest.mark.parametrize("wires, shape", [([0, 1, 2], 8), ([], 16)]) + def test_wires(self, wires, shape, x64_mode): + """Tests capturing probabilities on wires.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return qml.probs(wires=wires) + + jaxpr = jax.make_jaxpr(f)() + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == ProbabilityMP._wires_primitive + assert [x.val for x in jaxpr.eqns[0].invars] == wires + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == len(wires) + assert mp._abstract_eval == ProbabilityMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=4 + )[0] + assert shapes == jax.core.ShapedArray( + (shape,), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_eigvals(self, x64_mode): + """Test capturing probabilities with eigenvalues.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(eigs): + return ProbabilityMP(eigvals=eigs, wires=qml.wires.Wires((0, 1))) + + eigvals = np.array([-1.0, -0.5, 0.5, 1.0]) + jaxpr = jax.make_jaxpr(f)(eigvals) + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == ProbabilityMP._wires_primitive + assert jaxpr.eqns[0].params == {"has_eigvals": True} + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 2 + assert mp._abstract_eval == ProbabilityMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals) + assert shapes[0] == jax.core.ShapedArray( + (4,), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_multiple_mcms(self, x64_mode): + """Test measuring probabilities of multiple mcms.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(c1, c2): + return qml.probs(op=[c1, c2]) + + jaxpr = jax.make_jaxpr(f)(1, 2) + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == ProbabilityMP._mcm_primitive + out = jaxpr.eqns[0].outvars[0].aval + assert isinstance(out, AbstractMeasurement) + assert out.n_wires == 2 + assert out._abstract_eval == ProbabilityMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=6 + ) + assert shapes[0] == jax.core.ShapedArray( + (4,), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + with pytest.raises(NotImplementedError): + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1, 2) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +class TestSample: + + @pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4)]) + def test_wires(self, wires, dim1_len, x64_mode): + """Tests capturing samples on wires.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(*inner_wires): + return qml.sample(wires=inner_wires) + + jaxpr = jax.make_jaxpr(f)(*wires) + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == SampleMP._wires_primitive + assert [x.aval for x in jaxpr.eqns[0].invars] == jaxpr.in_avals + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == len(wires) + assert mp._abstract_eval == SampleMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=4 + ) + assert shapes[0] == jax.core.ShapedArray( + (50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) + + with pytest.raises(ValueError, match="finite shots are required"): + jaxpr.out_avals[0].abstract_eval(shots=None, num_device_wires=4) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_eigvals(self, x64_mode): + """Test capturing samples with eigenvalues.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(eigs): + return SampleMP(eigvals=eigs, wires=qml.wires.Wires((0, 1))) + + eigvals = np.array([-1.0, -0.5, 0.5, 1.0]) + jaxpr = jax.make_jaxpr(f)(eigvals) + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == SampleMP._wires_primitive + assert jaxpr.eqns[0].params == {"has_eigvals": True} + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 2 + assert mp._abstract_eval == SampleMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals, shots=qml.measurements.Shots(50)) + assert shapes[0] == jax.core.ShapedArray( + (50,), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_multiple_mcms(self, x64_mode): + """Test sampling from multiple mcms.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return qml.sample(op=[1, 2]) + + jaxpr = jax.make_jaxpr(f)() + + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == SampleMP._mcm_primitive + out = jaxpr.eqns[0].outvars[0].aval + assert isinstance(out, AbstractMeasurement) + assert out.n_wires == 2 + assert out._abstract_eval == SampleMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals, shots=qml.measurements.Shots(50)) + assert shapes[0] == jax.core.ShapedArray( + (50, 2), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) + + with pytest.raises(NotImplementedError): + f() + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_shadow_expval(x64_mode): + """Test that the shadow expval of an observable can be captured.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(): + return qml.shadow_expval(qml.X(0), seed=887, k=4) + + jaxpr = jax.make_jaxpr(f)() + + assert len(jaxpr.eqns) == 2 + assert jaxpr.eqns[0].primitive == qml.X._primitive + + assert jaxpr.eqns[1].primitive == ShadowExpvalMP._obs_primitive + assert jaxpr.eqns[0].outvars == jaxpr.eqns[1].invars + assert jaxpr.eqns[1].params == {"seed": 887, "k": 4} + + am = jaxpr.eqns[1].outvars[0].aval + assert isinstance(am, AbstractMeasurement) + assert am.n_wires is None + assert am._abstract_eval == ShadowExpvalMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals, num_device_wires=0, shots=qml.measurements.Shots(50)) + assert shapes[0] == jax.core.ShapedArray( + (), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +@pytest.mark.parametrize("mtype, kwargs", [(VnEntropyMP, {"log_base": 2}), (PurityMP, {})]) +def test_qinfo_measurements(mtype, kwargs, x64_mode): + """Test the capture of a vn entropy and purity measurement.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(w1, w2): + return mtype(wires=qml.wires.Wires([w1, w2]), **kwargs) + + jaxpr = jax.make_jaxpr(f)(1, 2) + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == mtype._wires_primitive + assert jaxpr.eqns[0].params == kwargs + assert len(jaxpr.eqns[0].invars) == 2 + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 2 + assert mp._abstract_eval == mtype._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, num_device_wires=4, shots=qml.measurements.Shots(None) + ) + assert shapes[0] == jax.core.ShapedArray( + (), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_MutualInfo(x64_mode): + """Test the capture of a vn entropy and purity measurement.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + def f(w1, w2): + return qml.mutual_info(wires0=[w1, 1], wires1=[w2, 3], log_base=2) + + jaxpr = jax.make_jaxpr(f)(0, 2) + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == MutualInfoMP._wires_primitive + assert jaxpr.eqns[0].params == {"log_base": 2, "n_wires0": 2} + assert len(jaxpr.eqns[0].invars) == 4 + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp._abstract_eval == MutualInfoMP._abstract_eval + + shapes = _get_shapes_for( + *jaxpr.out_avals, num_device_wires=4, shots=qml.measurements.Shots(None) + ) + assert shapes[0] == jax.core.ShapedArray( + (), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + +def test_ClassicalShadow(): + """Test that the classical shadow measurement can be captured.""" + + def f(): + return qml.classical_shadow(wires=(0, 1, 2), seed=95) + + jaxpr = jax.make_jaxpr(f)() + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 1 + + assert jaxpr.eqns[0].primitive == ClassicalShadowMP._wires_primitive + assert jaxpr.eqns[0].params == {"seed": 95} + assert len(jaxpr.eqns[0].invars) == 3 + mp = jaxpr.eqns[0].outvars[0].aval + assert isinstance(mp, AbstractMeasurement) + assert mp.n_wires == 3 + assert mp._abstract_eval == ClassicalShadowMP._abstract_eval + + shapes = _get_shapes_for(*jaxpr.out_avals, num_device_wires=4, shots=qml.measurements.Shots(50)) + assert shapes[0] == jax.core.ShapedArray((2, 50, 3), np.int8)