Skip to content

Commit

Permalink
Minor program capture fixes (#5889)
Browse files Browse the repository at this point in the history
**Context:**

Catalyst PR #837 (PennyLaneAI/catalyst#837)
needs a couple minor updates to the capture module.

**Description of the Change:**

1) makes it possible to do `from pennylane.capture import
AbstractOperator, AbstractMeasurement, qnode_prim` so we don't have to
touch private functions

2) Adds `qnode` as a keyword argument that gets bound to the qnode
primitive

3) Makes it so we can capture a sample measurement specified like
`qml.sample(wires=1)`

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-66703]

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
  • Loading branch information
albi3ro and dwierichs authored Jun 21, 2024
1 parent 21b40c5 commit 96a3d56
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 22 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@
[(#5708)](https://github.com/PennyLaneAI/pennylane/pull/5708)
[(#5523)](https://github.com/PennyLaneAI/pennylane/pull/5523)
[(#5686)](https://github.com/PennyLaneAI/pennylane/pull/5686)
[(#5889)](https://github.com/PennyLaneAI/pennylane/pull/5889)

* The `decompose` transform has an `error` kwarg to specify the type of error that should be raised,
allowing error types to be more consistent with the context the `decompose` function is used in.
Expand Down
35 changes: 35 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,45 @@ def _(*args, **kwargs):
)
from .capture_qnode import qnode_call

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
# on use of from capture import AbstractOperator
AbstractOperator: type
AbstractMeasurement: type
qnode_prim: "jax.core.Primitive"


def __getattr__(key):
if key == "AbstractOperator":
from .primitives import _get_abstract_operator # pylint: disable=import-outside-toplevel

return _get_abstract_operator()

if key == "AbstractMeasurement":
from .primitives import _get_abstract_measurement # pylint: disable=import-outside-toplevel

return _get_abstract_measurement()

if key == "qnode_prim":
from .capture_qnode import _get_qnode_prim # pylint: disable=import-outside-toplevel

return _get_qnode_prim()

raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'")


__all__ = (
"disable",
"enable",
"enabled",
"CaptureMeta",
"ABCCaptureMeta",
"create_operator_primitive",
"create_measurement_obs_primitive",
"create_measurement_wires_primitive",
"create_measurement_mcm_primitive",
"qnode_call",
"AbstractOperator",
"AbstractMeasurement",
"qnode_prim",
)
14 changes: 10 additions & 4 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _get_qnode_prim():
qnode_prim.multiple_results = True

@qnode_prim.def_impl
def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr):
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr):
def qfunc(*inner_args):
return jax.core.eval_jaxpr(qfunc_jaxpr.jaxpr, qfunc_jaxpr.consts, *inner_args)

Expand All @@ -70,7 +70,7 @@ def qfunc(*inner_args):

# pylint: disable=unused-argument
@qnode_prim.def_abstract_eval
def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr):
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr):
mps = qfunc_jaxpr.out_avals
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

Expand Down Expand Up @@ -166,6 +166,12 @@ def f(x):
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_prim = _get_qnode_prim()

return qnode_prim.bind(
*args, shots=shots, device=qnode.device, qnode_kwargs=qnode_kwargs, qfunc_jaxpr=qfunc_jaxpr
res = qnode_prim.bind(
*args,
shots=shots,
qnode=qnode,
device=qnode.device,
qnode_kwargs=qnode_kwargs,
qfunc_jaxpr=qfunc_jaxpr,
)
return res[0] if len(res) == 1 else res
2 changes: 1 addition & 1 deletion pennylane/measurements/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def circuit(x):
[0, 0]])
"""
return SampleMP(obs=op, wires=wires)
return SampleMP(obs=op, wires=None if wires is None else qml.wires.Wires(wires))


class SampleMP(SampleMeasurement):
Expand Down
5 changes: 2 additions & 3 deletions tests/capture/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import pytest

import pennylane as qml
from pennylane.capture.capture_qnode import _get_qnode_prim

qnode_prim = _get_qnode_prim()
from pennylane.capture import qnode_prim

pytestmark = pytest.mark.jax

Expand Down Expand Up @@ -119,6 +117,7 @@ def circuit(x):
assert jaxpr.out_avals[0] == jax.core.ShapedArray((), fdtype)

assert eqn0.params["device"] == dev
assert eqn0.params["qnode"] == circuit
assert eqn0.params["shots"] == qml.measurements.Shots(None)
expected_kwargs = {"diff_method": "best"}
expected_kwargs.update(circuit.execute_kwargs)
Expand Down
27 changes: 18 additions & 9 deletions tests/capture/test_measurements_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import pytest

import pennylane as qml
from pennylane.capture.primitives import _get_abstract_measurement
from pennylane.measurements import (
ClassicalShadowMP,
DensityMatrixMP,
Expand All @@ -38,9 +37,9 @@

jax = pytest.importorskip("jax")

pytestmark = pytest.mark.jax
from pennylane.capture import AbstractMeasurement # pylint: disable=wrong-import-position

AbstractMeasurement = _get_abstract_measurement()
pytestmark = pytest.mark.jax


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -453,32 +452,42 @@ def f(c1, c2):
@pytest.mark.parametrize("x64_mode", (True, False))
class TestSample:

@pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4)])
@pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4), (1, 1)])
def test_wires(self, wires, dim1_len, x64_mode):
"""Tests capturing samples on wires."""

initial_mode = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", x64_mode)

def f(*inner_wires):
return qml.sample(wires=inner_wires)
if isinstance(wires, list):

def f(*inner_wires):
return qml.sample(wires=inner_wires)

jaxpr = jax.make_jaxpr(f)(*wires)
jaxpr = jax.make_jaxpr(f)(*wires)
else:

def f(inner_wire):
return qml.sample(wires=inner_wire)

jaxpr = jax.make_jaxpr(f)(wires)

assert len(jaxpr.eqns) == 1

assert jaxpr.eqns[0].primitive == SampleMP._wires_primitive
assert [x.aval for x in jaxpr.eqns[0].invars] == jaxpr.in_avals
mp = jaxpr.eqns[0].outvars[0].aval
assert isinstance(mp, AbstractMeasurement)
assert mp.n_wires == len(wires)
assert mp.n_wires == len(wires) if isinstance(wires, list) else 1
assert mp._abstract_eval == SampleMP._abstract_eval

shapes = _get_shapes_for(
*jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=4
)
assert len(shapes) == 1
shape = (50, dim1_len) if isinstance(wires, list) else (50,)
assert shapes[0] == jax.core.ShapedArray(
(50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32
shape, jax.numpy.int64 if x64_mode else jax.numpy.int32
)

with pytest.raises(ValueError, match="finite shots are required"):
Expand Down
5 changes: 2 additions & 3 deletions tests/capture/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
import pytest

import pennylane as qml
from pennylane.capture.primitives import _get_abstract_operator

jax = pytest.importorskip("jax")

pytestmark = pytest.mark.jax
from pennylane.capture import AbstractOperator # pylint: disable=wrong-import-position

AbstractOperator = _get_abstract_operator()
pytestmark = pytest.mark.jax


@pytest.fixture(autouse=True)
Expand Down
2 changes: 0 additions & 2 deletions tests/capture/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@
import pytest

import pennylane as qml
from pennylane.capture.primitives import _get_abstract_operator

jax = pytest.importorskip("jax")
jnp = jax.numpy

pytestmark = pytest.mark.jax

AbstractOperator = _get_abstract_operator()
original_op_bind_code = qml.operation.Operator._primitive_bind_call.__code__


Expand Down

0 comments on commit 96a3d56

Please sign in to comment.