Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Capture a qnode into jaxpr #5708

Merged
merged 109 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 103 commits
Commits
Show all changes
109 commits
Select commit Hold shift + click to select a range
dd1f380
first pass
albi3ro Apr 12, 2024
83e4fe7
add module
dwierichs Apr 12, 2024
22fbadc
changelog
dwierichs Apr 12, 2024
156a749
Merge branch 'master' into add-capture-module
dwierichs Apr 12, 2024
ad7b637
import, fix
dwierichs Apr 12, 2024
d9b8b8a
git
dwierichs Apr 12, 2024
1634916
tests
dwierichs Apr 12, 2024
f6ba19d
Merge branch 'master' into add-capture-module
dwierichs Apr 12, 2024
15559d6
move switches to switches.py
dwierichs Apr 12, 2024
c511fb5
lint
dwierichs Apr 12, 2024
add82e9
identify all operators as jax primitives
albi3ro Apr 12, 2024
242cad1
Merge branch 'add-capture-module' into plxpr-capture-operations
albi3ro Apr 12, 2024
4c6c7bb
add dunder math support
albi3ro Apr 12, 2024
1d927b8
allow overriding primmitive bind call
albi3ro Apr 15, 2024
67a2069
improving testing
albi3ro Apr 15, 2024
45ab87a
fix up to allow using abc still
albi3ro Apr 15, 2024
c8de208
Update pennylane/capture/meta_type.py
albi3ro Apr 15, 2024
fe4ebcf
Update pennylane/operation.py
albi3ro Apr 15, 2024
27c78b8
adding some more documentation
albi3ro Apr 16, 2024
00c09e0
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro Apr 16, 2024
68020ca
Merge branch 'master' into plxpr-capture-operations
dwierichs Apr 16, 2024
9747e1e
Apply suggestions from code review
albi3ro Apr 16, 2024
21d12e0
pow support
albi3ro Apr 16, 2024
58367bc
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro Apr 16, 2024
c83a42a
minor fixes
albi3ro Apr 16, 2024
712a0cf
fix pauli rot
albi3ro Apr 16, 2024
f2fbe31
Update pennylane/capture/__init__.py
albi3ro Apr 17, 2024
8c2a4eb
responding to feedback
albi3ro Apr 17, 2024
121fbbe
move metaclass initialization to __init_subclass__
albi3ro Apr 18, 2024
cf8f2aa
responding to feedback, changelog
albi3ro Apr 19, 2024
136a5cd
improve testing for evaluating the jaxpr
albi3ro Apr 22, 2024
4ce8720
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 22, 2024
1ec85ab
automatically capture measurements with plxpr
albi3ro Apr 22, 2024
427a78c
fixes
albi3ro Apr 23, 2024
e35f602
Apply suggestions from code review
albi3ro Apr 23, 2024
d6bd42d
[skip ci] responding to feedback and polishing
albi3ro Apr 23, 2024
afd22f6
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 23, 2024
8b03ab0
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 24, 2024
a1bb1f9
Merge branch 'master' into plxpr-capture-operations
albi3ro Apr 25, 2024
c1c26c4
Update pennylane/capture/meta_type.py
albi3ro Apr 25, 2024
3d47a27
Update pennylane/measurements/sample.py
albi3ro Apr 30, 2024
2a2fb72
Update pennylane/capture/meta_type.py
albi3ro Apr 30, 2024
e25cb25
Update pennylane/capture/meta_type.py
albi3ro Apr 30, 2024
3b5cdc9
[skip ci] rename to CaptureMeta, create primitives file
albi3ro May 3, 2024
2b70e77
Merge branch 'master' into plxpr-capture-operations
albi3ro May 6, 2024
1c03b95
add source code clarification
albi3ro May 6, 2024
3792377
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro May 6, 2024
9bc466f
changelog
albi3ro May 6, 2024
8a1fead
Update tests/capture/test_operators.py
albi3ro May 6, 2024
43aba1d
Update pennylane/capture/__init__.py
albi3ro May 7, 2024
d2864a4
Merge branch 'master' into plxpr-capture-operations
albi3ro May 7, 2024
e269c9a
Update tests/capture/test_operators.py
albi3ro May 7, 2024
41a24ef
Update tests/capture/test_operators.py
albi3ro May 7, 2024
4f74c34
Apply suggestions from code review
albi3ro May 7, 2024
867d54f
responding to feedback
albi3ro May 7, 2024
8796db8
merge
albi3ro May 7, 2024
3220f10
final code review responses
albi3ro May 7, 2024
c0be87f
minor fixes
albi3ro May 8, 2024
ac7ea0a
Update pennylane/capture/primitives.py
albi3ro May 9, 2024
66a9d7d
Merge branch 'master' into plxpr-capture-operations
albi3ro May 9, 2024
63b5f4b
pylint
albi3ro May 9, 2024
d190afe
Merge branch 'plxpr-capture-operations' of https://github.com/PennyLa…
albi3ro May 9, 2024
2fa7799
remove trailing whitespace
albi3ro May 9, 2024
a1d5d92
Merge branch 'plxpr-capture-measurements' of https://github.com/Penny…
albi3ro May 9, 2024
4040fc8
merging
albi3ro May 9, 2024
3c8ed40
minor fixes
albi3ro May 10, 2024
391d77e
Merge branch 'master' into plxpr-capture-measurements
albi3ro May 14, 2024
c759a86
Update pennylane/capture/measure.py
albi3ro May 14, 2024
5cf9e47
writing docstrings [skip-ci]
albi3ro May 14, 2024
762dba5
adding testing and mcm support
albi3ro May 15, 2024
edcd62b
more testing
albi3ro May 16, 2024
e81e9a9
capture qnode into jaxpr
albi3ro May 17, 2024
1fca5dd
Merge branch 'master' into plxpr-capture-measurements
albi3ro May 17, 2024
c8203e5
more testing
albi3ro May 17, 2024
e1c200a
rename test file
albi3ro May 17, 2024
f003163
Merge branch 'master' into plxpr-capture-measurements
albi3ro May 17, 2024
18c1c3f
Merge branch 'master' into plxpr-capture-measurements
albi3ro May 21, 2024
f3c0923
documentation improvements
albi3ro May 21, 2024
c0e356e
Update doc/releases/changelog-dev.md
albi3ro May 21, 2024
fbed5d0
Update pennylane/capture/measure.py
albi3ro May 21, 2024
deadbe5
Apply suggestions from code review
albi3ro May 22, 2024
12a5e8e
more tests
albi3ro May 23, 2024
a35ff18
oops
albi3ro May 23, 2024
2629f9b
oops
albi3ro May 23, 2024
d00d4a7
more docs
albi3ro May 23, 2024
5bcecb1
merge measurements PR branch in
albi3ro May 23, 2024
f367291
add testing
albi3ro May 23, 2024
f825c9e
remove measure function
albi3ro May 23, 2024
5cdfa87
Merge branch 'master' into plxpr-capture-measurements
albi3ro May 23, 2024
899bec3
Update pennylane/capture/__init__.py
albi3ro May 23, 2024
310ebda
additional tests and minor fixes
albi3ro May 24, 2024
c67d86e
Apply suggestions from code review
albi3ro May 24, 2024
c1b3108
dynamic wires testing
albi3ro May 24, 2024
40f5cde
Merge branch 'master' into plxpr-capture-measurements
albi3ro May 24, 2024
b3dadce
responding to comments
albi3ro May 28, 2024
23f0677
merge in target branch
albi3ro May 28, 2024
087d89b
remove measure function
albi3ro May 28, 2024
c66d67b
minor fixes
albi3ro May 28, 2024
57c132b
adding more tests and some minor fixes
albi3ro May 28, 2024
2a3184b
Update pennylane/capture/capture_qnode.py
albi3ro May 28, 2024
7b9c284
code cov and trying to fix sphinx
albi3ro May 28, 2024
b6d1bd6
Apply suggestions from code review [skip-ci]
albi3ro May 29, 2024
3fd7a3c
testing improvements, doc build attempt
albi3ro May 29, 2024
47fdb63
Merge branch 'master' into capture-qnode
albi3ro May 30, 2024
fd3432c
trying to fix docstring
albi3ro May 30, 2024
08e6960
Update tests/capture/test_capture_qnode.py
albi3ro Jun 3, 2024
d71b366
Update pennylane/capture/capture_qnode.py
albi3ro Jun 3, 2024
cca8025
split out shot vector test
albi3ro Jun 3, 2024
aab3442
Merge branch 'master' into capture-qnode
albi3ro Jun 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@
`m = measure(0); qml.sample(m)`.
[(#5673)](https://github.com/PennyLaneAI/pennylane/pull/5673)

* PennyLane operators can now automatically be captured as instructions in JAXPR. See the experimental
`capture` module for more information.
* PennyLane operators, measurements, and QNodes can now automatically be captured as instructions in JAXPR.
[(#5564)](https://github.com/PennyLaneAI/pennylane/pull/5564)
[(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511)
[(#5708)](https://github.com/PennyLaneAI/pennylane/pull/5708)

* The `decompose` transform has an `error` kwarg to specify the type of error that should be raised,
allowing error types to be more consistent with the context the `decompose` function is used in.
Expand Down
29 changes: 28 additions & 1 deletion pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@

This module is experimental and will change significantly in the future.

.. currentmodule:: pennylane.capture

.. autosummary::
:toctree: api

~disable
~enable
~enabled
~create_operator_primitive
~create_measurement_obs_primitive
~create_measurement_wires_primitive
~create_measurement_mcm_primitive
~qnode_call

To activate and deactivate the new PennyLane program capturing mechanism, use
the switches ``qml.capture.enable`` and ``qml.capture.disable``.
Expand Down Expand Up @@ -114,4 +127,18 @@ def _(*args, **kwargs):
"""
from .switches import disable, enable, enabled
from .capture_meta import CaptureMeta
from .primitives import create_operator_primitive
from .primitives import (
create_operator_primitive,
create_measurement_obs_primitive,
create_measurement_wires_primitive,
create_measurement_mcm_primitive,
)
from .capture_qnode import qnode_call


def __getattr__(key):
if key == "AbstractOperator":
from .primitives import _get_abstract_operator # pylint: disable=import-outside-toplevel

return _get_abstract_operator()
raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'")
33 changes: 33 additions & 0 deletions pennylane/capture/capture_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,39 @@ class CaptureMeta(type):

See ``pennylane/capture/explanations.md`` for more detailed information on how this technically
works.

.. code-block::

class AbstractMyObj(jax.core.AbstractValue):
pass

jax.core.raise_to_shaped_mappings[AbstractMyObj] = lambda aval, _: aval

class MyObj(metaclass=qml.capture.CaptureMeta):

primitive = jax.core.Primitive("MyObj")

@classmethod
def _primitive_bind_call(cls, a):
return cls.primitive.bind(a)

def __init__(self, a):
self.a = a

@MyObj.primitive.def_impl
def _(a):
return type.__call__(MyObj, a)

@MyObj.primitive.def_abstract_eval
def _(a):
return AbstractMyObj()

>>> jaxpr = jax.make_jaxpr(MyObj)(0.1)
>>> jaxpr
{ lambda ; a:f32[]. let b:AbstractMyObj() = MyObj a in (b,) }
>>> jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.1)
[<__main__.MyObj at 0x17fc3ea50>]

"""

@property
Expand Down
162 changes: 162 additions & 0 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule defines a capture compatible call to QNodes.
"""

from functools import lru_cache, partial

import pennylane as qml

has_jax = True
try:
import jax
except ImportError:
has_jax = False


def _get_shapes_for(*measurements, shots=None, num_device_wires=0):
if jax.config.jax_enable_x64:
dtype_map = {
float: jax.numpy.float64,
int: jax.numpy.int64,
complex: jax.numpy.complex128,
}
else:
dtype_map = {
float: jax.numpy.float32,
int: jax.numpy.int32,
complex: jax.numpy.complex64,
}

shapes = []
if not shots:
shots = [None]

for s in shots:
for m in measurements:
shape, dtype = m.abstract_eval(shots=s, num_device_wires=num_device_wires)
shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype)))
return shapes


@lru_cache()
def _get_qnode_prim():
if not has_jax:
return None
qnode_prim = jax.core.Primitive("qnode")
qnode_prim.multiple_results = True

@qnode_prim.def_impl
def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr):
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
def qfunc(*inner_args):
return jax.core.eval_jaxpr(qfunc_jaxpr.jaxpr, qfunc_jaxpr.consts, *inner_args)

qnode = qml.QNode(qfunc, device, **qnode_kwargs)
return qnode._impl_call(*args, shots=shots) # pylint: disable=protected-access

# pylint: disable=unused-argument
@qnode_prim.def_abstract_eval
def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr):
mps = qfunc_jaxpr.out_avals
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

return qnode_prim


# pylint: disable=protected-access
def _get_device_shots(device) -> "qml.measurements.Shots":
if isinstance(device, qml.devices.LegacyDevice):
if device._shot_vector:
return qml.measurements.Shots(device._raw_shot_sequence)
return qml.measurements.Shots(device.shots)
return device.shots


def qnode_call(qnode: "qml.QNode", *args, **kwargs) -> "qml.typing.Result":
"""A capture compatible call to a QNode. This function is internally used by ``QNode.__call__``.

Args:
qnode (QNode): a QNode
args: the arguments the QNode is called with

Keyword Args:
Any keyword arguments accepted by the quantum function
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

Returns:
qml.typing.Result: the result of a qnode execution

**Example:**

.. code-block:: python

@qml.qnode(qml.device('lightning.qubit', wires=1))
def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.Z(0)), qml.probs()

def f(x):
expval_z, probs = circuit(np.pi * x, shots=50)
return 2*expval_z + probs

jaxpr = jax.make_jaxpr(f)(0.1)
print("jaxpr: \n", jaxpr)

res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.7)
print("\nresult: \n", res)


.. code-block:: none

jaxpr:
{ lambda ; a:f32[]. let
b:f32[] = mul 3.141592653589793 a
c:f32[] d:f32[2] = qnode[
device=<lightning.qubit device (wires=1) at 0x30755b3d0>
qfunc_jaxpr={ lambda ; e:f32[]. let
_:AbstractOperator() = RX[n_wires=1] e 0
f:AbstractOperator() = PauliZ[n_wires=1] 0
g:AbstractMeasurement(n_wires=None) = expval_obs f
h:AbstractMeasurement(n_wires=0) = probs_wires
in (g, h) }
qnode_kwargs={'diff_method': 'best', 'grad_on_execution': 'best', 'cache': False, 'cachesize': 10000, 'max_diff': 1, 'max_expansion': 10, 'device_vjp': False}
shots=Shots(total=50)
] b
i:f32[] = mul 2.0 c
j:f32[2] = add i d
in (j,) }

result:
[Array([-1.08, -0.64], dtype=float32)]


albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""
shots = kwargs.pop("shots", _get_device_shots(qnode.device))
shots = qml.measurements.Shots(shots)
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
if shots.has_partitioned_shots:
# Questions over the pytrees and the nested result object shape
raise NotImplementedError("shot vectors are not yet supported with plxpr capture.")

if not qnode.device.wires:
raise NotImplementedError("devices must specify wires for integration with plxpr capture.")

qfunc = partial(qnode.func, **kwargs) if kwargs else qnode.func

qfunc_jaxpr = jax.make_jaxpr(qfunc)(*args)
qnode_kwargs = {"diff_method": qnode.diff_method, **qnode.execute_kwargs}
qnode_prim = _get_qnode_prim()

return qnode_prim.bind(
*args, shots=shots, device=qnode.device, qnode_kwargs=qnode_kwargs, qfunc_jaxpr=qfunc_jaxpr
)
Loading
Loading