Skip to content

Commit

Permalink
Capture measurements with jax.make_jaxpr. (#5564)
Browse files Browse the repository at this point in the history
**Context:**

**Description of the Change:**

**Benefits:**

We can capture measurements and keep track of their resulting shapes in
an extensible manner.

**Possible Drawbacks:**

Measurements do need a lot more hand-holding then observables due to the
"duel mode" inputs of either wires or observables, and the need to fully
specifying the resulting shape when we perform an actual measurment.

With the current framework, it will theoretically be extensible to add
new measurements, but the process is slightly more error prone.

We also now provide duplicate information from
`MeasurementProcess.numeric_type`, `MeasurementProcess.shape`, and
`MeasurementProcess._abstract_eval`. But the instance-based property and
method didn't play well with the need to track that information during
an abstract evaluation.

**Related GitHub Issues:**

[sc-61200] [sc-59452]

---------

Co-authored-by: dwierichs <david.wierichs@xanadu.ai>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
  • Loading branch information
3 people committed May 30, 2024
1 parent 9d6a846 commit 60d2b5a
Show file tree
Hide file tree
Showing 16 changed files with 1,196 additions and 34 deletions.
4 changes: 2 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@
`m = measure(0); qml.sample(m)`.
[(#5673)](https://github.com/PennyLaneAI/pennylane/pull/5673)

* PennyLane operators can now automatically be captured as instructions in JAXPR. See the experimental
`capture` module for more information.
* PennyLane operators and measurements can now automatically be captured as instructions in JAXPR.
[(#5564)](https://github.com/PennyLaneAI/pennylane/pull/5564)
[(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511)

* The `decompose` transform has an `error` kwarg to specify the type of error that should be raised,
Expand Down
27 changes: 26 additions & 1 deletion pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
This module is experimental and will change significantly in the future.
.. currentmodule:: pennylane.capture
.. autosummary::
:toctree: api
~disable
~enable
~enabled
~create_operator_primitive
~create_measurement_obs_primitive
~create_measurement_wires_primitive
~create_measurement_mcm_primitive
To activate and deactivate the new PennyLane program capturing mechanism, use
the switches ``qml.capture.enable`` and ``qml.capture.disable``.
Expand Down Expand Up @@ -114,4 +126,17 @@ def _(*args, **kwargs):
"""
from .switches import disable, enable, enabled
from .capture_meta import CaptureMeta
from .primitives import create_operator_primitive
from .primitives import (
create_operator_primitive,
create_measurement_obs_primitive,
create_measurement_wires_primitive,
create_measurement_mcm_primitive,
)


def __getattr__(key):
if key == "AbstractOperator":
from .primitives import _get_abstract_operator # pylint: disable=import-outside-toplevel

return _get_abstract_operator()
raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'")
33 changes: 33 additions & 0 deletions pennylane/capture/capture_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,39 @@ class CaptureMeta(type):
See ``pennylane/capture/explanations.md`` for more detailed information on how this technically
works.
.. code-block::
class AbstractMyObj(jax.core.AbstractValue):
pass
jax.core.raise_to_shaped_mappings[AbstractMyObj] = lambda aval, _: aval
class MyObj(metaclass=qml.capture.CaptureMeta):
primitive = jax.core.Primitive("MyObj")
@classmethod
def _primitive_bind_call(cls, a):
return cls.primitive.bind(a)
def __init__(self, a):
self.a = a
@MyObj.primitive.def_impl
def _(a):
return type.__call__(MyObj, a)
@MyObj.primitive.def_abstract_eval
def _(a):
return AbstractMyObj()
>>> jaxpr = jax.make_jaxpr(MyObj)(0.1)
>>> jaxpr
{ lambda ; a:f32[]. let b:AbstractMyObj() = MyObj a in (b,) }
>>> jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.1)
[<__main__.MyObj at 0x17fc3ea50>]
"""

@property
Expand Down
208 changes: 206 additions & 2 deletions pennylane/capture/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

from functools import lru_cache
from typing import Optional
from typing import Callable, Optional, Tuple, Type

import pennylane as qml

Expand Down Expand Up @@ -82,9 +82,94 @@ def _pow(a, b):
return AbstractOperator


def create_operator_primitive(operator_type: type) -> Optional["jax.core.Primitive"]:
@lru_cache
def _get_abstract_measurement():
if not has_jax: # pragma: no cover
raise ImportError("Jax is required for plxpr.") # pragma: no cover

class AbstractMeasurement(jax.core.AbstractValue):
"""An abstract measurement.
Args:
abstract_eval (Callable): See :meth:`~.MeasurementProcess._abstract_eval`. A function of
``n_wires``, ``has_eigvals``, ``num_device_wires`` and ``shots`` that returns a shape
and numeric type.
n_wires=None (Optional[int]): the number of wires
has_eigvals=False (bool): Whether or not the measurement contains eigenvalues in a wires+eigvals
diagonal representation.
"""

def __init__(
self, abstract_eval: Callable, n_wires: Optional[int] = None, has_eigvals: bool = False
):
self._abstract_eval = abstract_eval
self._n_wires = n_wires
self.has_eigvals: bool = has_eigvals

def abstract_eval(self, num_device_wires: int, shots: int) -> Tuple[Tuple, type]:
"""Calculate the shape and dtype for an evaluation with specified number of device
wires and shots.
"""
return self._abstract_eval(
n_wires=self._n_wires,
has_eigvals=self.has_eigvals,
num_device_wires=num_device_wires,
shots=shots,
)

@property
def n_wires(self) -> Optional[int]:
"""The number of wires for a wire based measurement.
Options are:
* ``None``: The measurement is observable based or single mcm based
* ``0``: The measurement is broadcasted across all available devices wires
* ``int>0``: A wire or mcm based measurement with specified wires or mid circuit measurements.
"""
return self._n_wires

def __repr__(self):
if self.has_eigvals:
return f"AbstractMeasurement(n_wires={self.n_wires}, has_eigvals=True)"
return f"AbstractMeasurement(n_wires={self.n_wires})"

# pylint: disable=missing-function-docstring
def at_least_vspace(self):
# TODO: investigate the proper definition of this method
raise NotImplementedError

# pylint: disable=missing-function-docstring
def join(self, other):
# TODO: investigate the proper definition of this method
raise NotImplementedError

# pylint: disable=missing-function-docstring
def update(self, **kwargs):
# TODO: investigate the proper definition of this method
raise NotImplementedError

def __eq__(self, other):
return isinstance(other, AbstractMeasurement)

def __hash__(self):
return hash("AbstractMeasurement")

jax.core.raise_to_shaped_mappings[AbstractMeasurement] = lambda aval, _: aval

return AbstractMeasurement


def create_operator_primitive(
operator_type: Type["qml.operation.Operator"],
) -> Optional["jax.core.Primitive"]:
"""Create a primitive corresponding to an operator type.
Called when defining any :class:`~.Operator` subclass, and is used to set the
``Operator._primitive`` class property.
Args:
operator_type (type): a subclass of qml.operation.Operator
Expand Down Expand Up @@ -117,3 +202,122 @@ def _(*_, **__):
return abstract_type()

return primitive


def create_measurement_obs_primitive(
measurement_type: Type["qml.measurements.MeasurementProcess"], name: str
) -> Optional["jax.core.Primitive"]:
"""Create a primitive corresponding to the input type where the abstract inputs are an operator.
Called by default when defining any class inheriting from :class:`~.MeasurementProcess`, and is used to
set the ``MeasurementProcesss._obs_primitive`` property.
Args:
measurement_type (type): a subclass of :class:`~.MeasurementProcess`
name (str): the preferred string name for the class. For example, ``"expval"``.
``"_obs"`` is appended to this name for the name of the primitive.
Returns:
Optional[jax.core.Primitive]: A new jax primitive. ``None`` is returned if jax is not available.
"""
if not has_jax:
return None

primitive = jax.core.Primitive(name + "_obs")

@primitive.def_impl
def _(obs, **kwargs):
return type.__call__(measurement_type, obs=obs, **kwargs)

abstract_type = _get_abstract_measurement()

@primitive.def_abstract_eval
def _(*_, **__):
abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access
return abstract_type(abstract_eval, n_wires=None)

return primitive


def create_measurement_mcm_primitive(
measurement_type: Type["qml.measurements.MeasurementProcess"], name: str
) -> Optional["jax.core.Primitive"]:
"""Create a primitive corresponding to the input type where the abstract inputs are classical
mid circuit measurement results.
Called by default when defining any class inheriting from :class:`~.MeasurementProcess`, and is used to
set the ``MeasurementProcesss._mcm_primitive`` property.
Args:
measurement_type (type): a subclass of :class:`~.MeasurementProcess`
name (str): the preferred string name for the class. For example, ``"expval"``.
``"_mcm"`` is appended to this name for the name of the primitive.
Returns:
Optional[jax.core.Primitive]: A new jax primitive. ``None`` is returned if jax is not available.
"""

if not has_jax:
return None

primitive = jax.core.Primitive(name + "_mcm")

@primitive.def_impl
def _(*mcms, **kwargs):
raise NotImplementedError(
"mcm measurements do not currently have a concrete implementation"
)
# need to figure out how to convert a jaxpr int into a measurement value, and pass
# that measurment value here.
# return type.__call__(measurement_type, obs=mcms, **kwargs)

abstract_type = _get_abstract_measurement()

@primitive.def_abstract_eval
def _(*mcms, **__):
abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access
return abstract_type(abstract_eval, n_wires=len(mcms))

return primitive


def create_measurement_wires_primitive(
measurement_type: type, name: str
) -> Optional["jax.core.Primitive"]:
"""Create a primitive corresponding to the input type where the abstract inputs are the wires.
Called by default when defining any class inheriting from :class:`~.MeasurementProcess`, and is used to
set the ``MeasurementProcesss._wires_primitive`` property.
Args:
measurement_type (type): a subclass of :class:`~.MeasurementProcess`
name (str): the preferred string name for the class. For example, ``"expval"``.
``"_wires"`` is appended to this name for the name of the primitive.
Returns:
Optional[jax.core.Primitive]: A new jax primitive. ``None`` is returned if jax is not available.
"""
if not has_jax:
return None

primitive = jax.core.Primitive(name + "_wires")

@primitive.def_impl
def _(*args, has_eigvals=False, **kwargs):
if has_eigvals:
wires = qml.wires.Wires(args[:-1])
kwargs["eigvals"] = args[-1]
else:
wires = qml.wires.Wires(args)
return type.__call__(measurement_type, wires=wires, **kwargs)

abstract_type = _get_abstract_measurement()

@primitive.def_abstract_eval
def _(*args, has_eigvals=False, **_):
abstract_eval = measurement_type._abstract_eval # pylint: disable=protected-access
n_wires = len(args) - 1 if has_eigvals else len(args)
return abstract_type(abstract_eval, n_wires=n_wires, has_eigvals=has_eigvals)

return primitive
30 changes: 30 additions & 0 deletions pennylane/measurements/classical_shadow.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,16 @@ def numeric_type(self):
def return_type(self):
return Shadow

@classmethod
def _abstract_eval(
cls,
n_wires: Optional[int] = None,
has_eigvals=False,
shots: Optional[int] = None,
num_device_wires: int = 0,
) -> tuple:
return (2, shots, n_wires), np.int8

def shape(self, device, shots): # pylint: disable=unused-argument
# otherwise, the return type requires a device
if not shots:
Expand Down Expand Up @@ -496,6 +506,19 @@ def __init__(
self.k = k
super().__init__(id=id)

# pylint: disable=arguments-differ
@classmethod
def _primitive_bind_call(
cls,
H: Union[Operator, Sequence],
seed: Optional[int] = None,
k: int = 1,
**kwargs,
):
if cls._obs_primitive is None: # pragma: no cover
return type.__call__(cls, H=H, seed=seed, k=k, **kwargs) # pragma: no cover
return cls._obs_primitive.bind(H, seed=seed, k=k, **kwargs)

def process(self, tape, device):
bits, recipes = qml.classical_shadow(wires=self.wires, seed=self.seed).process(tape, device)
shadow = qml.shadows.ClassicalShadow(bits, recipes, wire_map=self.wires.tolist())
Expand Down Expand Up @@ -573,3 +596,10 @@ def __copy__(self):
k=self.k,
seed=self.seed,
)


if ShadowExpvalMP._obs_primitive is not None: # pylint: disable=protected-access

@ShadowExpvalMP._obs_primitive.def_impl # pylint: disable=protected-access
def _(H, **kwargs):
return type.__call__(ShadowExpvalMP, H, **kwargs)
12 changes: 12 additions & 0 deletions pennylane/measurements/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,18 @@ def __repr__(self):

return f"CountsMP(wires={self.wires.tolist()}, all_outcomes={self.all_outcomes})"

@classmethod
def _abstract_eval(
cls,
n_wires: Optional[int] = None,
has_eigvals=False,
shots: Optional[int] = None,
num_device_wires: int = 0,
) -> tuple:
raise NotImplementedError(
"CountsMP returns a dictionary, which is not compatible with capture."
)

@property
def hash(self):
"""int: returns an integer hash uniquely representing the measurement process"""
Expand Down
4 changes: 1 addition & 3 deletions pennylane/measurements/expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ class ExpectationMP(SampleMeasurement, StateMeasurement):
where the instance has to be identified
"""

@property
def return_type(self):
return Expectation
return_type = Expectation

@property
def numeric_type(self):
Expand Down
Loading

0 comments on commit 60d2b5a

Please sign in to comment.