From 5305961fd38058a725a59427af0991b58996200b Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Fri, 21 Jun 2024 10:54:04 -0400 Subject: [PATCH] Minor program capture fixes (#5889) **Context:** Catalyst PR #837 (https://github.com/PennyLaneAI/catalyst/pull/837) needs a couple minor updates to the capture module. **Description of the Change:** 1) makes it possible to do `from pennylane.capture import AbstractOperator, AbstractMeasurement, qnode_prim` so we don't have to touch private functions 2) Adds `qnode` as a keyword argument that gets bound to the qnode primitive 3) Makes it so we can capture a sample measurement specified like `qml.sample(wires=1)` **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-66703] --------- Co-authored-by: David Wierichs --- doc/releases/changelog-dev.md | 1 + pennylane/capture/__init__.py | 35 ++++++++++++++++++++++ pennylane/capture/capture_qnode.py | 14 ++++++--- pennylane/measurements/sample.py | 2 +- tests/capture/test_capture_qnode.py | 5 ++-- tests/capture/test_measurements_capture.py | 27 +++++++++++------ tests/capture/test_operators.py | 5 ++-- tests/capture/test_templates.py | 2 -- 8 files changed, 69 insertions(+), 22 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index d2478262827..c2daebb11f8 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -282,6 +282,7 @@ [(#5708)](https://github.com/PennyLaneAI/pennylane/pull/5708) [(#5523)](https://github.com/PennyLaneAI/pennylane/pull/5523) [(#5686)](https://github.com/PennyLaneAI/pennylane/pull/5686) + [(#5889)](https://github.com/PennyLaneAI/pennylane/pull/5889) * The `decompose` transform has an `error` kwarg to specify the type of error that should be raised, allowing error types to be more consistent with the context the `decompose` function is used in. diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index c740d85700e..98637c2d425 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -135,10 +135,45 @@ def _(*args, **kwargs): ) from .capture_qnode import qnode_call +# by defining this here, we avoid +# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module) +# on use of from capture import AbstractOperator +AbstractOperator: type +AbstractMeasurement: type +qnode_prim: "jax.core.Primitive" + def __getattr__(key): if key == "AbstractOperator": from .primitives import _get_abstract_operator # pylint: disable=import-outside-toplevel return _get_abstract_operator() + + if key == "AbstractMeasurement": + from .primitives import _get_abstract_measurement # pylint: disable=import-outside-toplevel + + return _get_abstract_measurement() + + if key == "qnode_prim": + from .capture_qnode import _get_qnode_prim # pylint: disable=import-outside-toplevel + + return _get_qnode_prim() + raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'") + + +__all__ = ( + "disable", + "enable", + "enabled", + "CaptureMeta", + "ABCCaptureMeta", + "create_operator_primitive", + "create_measurement_obs_primitive", + "create_measurement_wires_primitive", + "create_measurement_mcm_primitive", + "qnode_call", + "AbstractOperator", + "AbstractMeasurement", + "qnode_prim", +) diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py index 6f305f88507..5852a57a0f9 100644 --- a/pennylane/capture/capture_qnode.py +++ b/pennylane/capture/capture_qnode.py @@ -61,7 +61,7 @@ def _get_qnode_prim(): qnode_prim.multiple_results = True @qnode_prim.def_impl - def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr): + def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr): def qfunc(*inner_args): return jax.core.eval_jaxpr(qfunc_jaxpr.jaxpr, qfunc_jaxpr.consts, *inner_args) @@ -70,7 +70,7 @@ def qfunc(*inner_args): # pylint: disable=unused-argument @qnode_prim.def_abstract_eval - def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr): + def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr): mps = qfunc_jaxpr.out_avals return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires)) @@ -166,6 +166,12 @@ def f(x): qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config} qnode_prim = _get_qnode_prim() - return qnode_prim.bind( - *args, shots=shots, device=qnode.device, qnode_kwargs=qnode_kwargs, qfunc_jaxpr=qfunc_jaxpr + res = qnode_prim.bind( + *args, + shots=shots, + qnode=qnode, + device=qnode.device, + qnode_kwargs=qnode_kwargs, + qfunc_jaxpr=qfunc_jaxpr, ) + return res[0] if len(res) == 1 else res diff --git a/pennylane/measurements/sample.py b/pennylane/measurements/sample.py index da10c55e61c..a8333266c8a 100644 --- a/pennylane/measurements/sample.py +++ b/pennylane/measurements/sample.py @@ -134,7 +134,7 @@ def circuit(x): [0, 0]]) """ - return SampleMP(obs=op, wires=wires) + return SampleMP(obs=op, wires=None if wires is None else qml.wires.Wires(wires)) class SampleMP(SampleMeasurement): diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index ca8f63ffed2..553c4485b73 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -21,9 +21,7 @@ import pytest import pennylane as qml -from pennylane.capture.capture_qnode import _get_qnode_prim - -qnode_prim = _get_qnode_prim() +from pennylane.capture import qnode_prim pytestmark = pytest.mark.jax @@ -119,6 +117,7 @@ def circuit(x): assert jaxpr.out_avals[0] == jax.core.ShapedArray((), fdtype) assert eqn0.params["device"] == dev + assert eqn0.params["qnode"] == circuit assert eqn0.params["shots"] == qml.measurements.Shots(None) expected_kwargs = {"diff_method": "best"} expected_kwargs.update(circuit.execute_kwargs) diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py index b516e0b30b9..a3bce757266 100644 --- a/tests/capture/test_measurements_capture.py +++ b/tests/capture/test_measurements_capture.py @@ -20,7 +20,6 @@ import pytest import pennylane as qml -from pennylane.capture.primitives import _get_abstract_measurement from pennylane.measurements import ( ClassicalShadowMP, DensityMatrixMP, @@ -38,9 +37,9 @@ jax = pytest.importorskip("jax") -pytestmark = pytest.mark.jax +from pennylane.capture import AbstractMeasurement # pylint: disable=wrong-import-position -AbstractMeasurement = _get_abstract_measurement() +pytestmark = pytest.mark.jax @pytest.fixture(autouse=True) @@ -453,17 +452,25 @@ def f(c1, c2): @pytest.mark.parametrize("x64_mode", (True, False)) class TestSample: - @pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4)]) + @pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4), (1, 1)]) def test_wires(self, wires, dim1_len, x64_mode): """Tests capturing samples on wires.""" initial_mode = jax.config.jax_enable_x64 jax.config.update("jax_enable_x64", x64_mode) - def f(*inner_wires): - return qml.sample(wires=inner_wires) + if isinstance(wires, list): + + def f(*inner_wires): + return qml.sample(wires=inner_wires) - jaxpr = jax.make_jaxpr(f)(*wires) + jaxpr = jax.make_jaxpr(f)(*wires) + else: + + def f(inner_wire): + return qml.sample(wires=inner_wire) + + jaxpr = jax.make_jaxpr(f)(wires) assert len(jaxpr.eqns) == 1 @@ -471,14 +478,16 @@ def f(*inner_wires): assert [x.aval for x in jaxpr.eqns[0].invars] == jaxpr.in_avals mp = jaxpr.eqns[0].outvars[0].aval assert isinstance(mp, AbstractMeasurement) - assert mp.n_wires == len(wires) + assert mp.n_wires == len(wires) if isinstance(wires, list) else 1 assert mp._abstract_eval == SampleMP._abstract_eval shapes = _get_shapes_for( *jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=4 ) + assert len(shapes) == 1 + shape = (50, dim1_len) if isinstance(wires, list) else (50,) assert shapes[0] == jax.core.ShapedArray( - (50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32 + shape, jax.numpy.int64 if x64_mode else jax.numpy.int32 ) with pytest.raises(ValueError, match="finite shots are required"): diff --git a/tests/capture/test_operators.py b/tests/capture/test_operators.py index 6388a3b6ea4..b5142dbb5e9 100644 --- a/tests/capture/test_operators.py +++ b/tests/capture/test_operators.py @@ -18,13 +18,12 @@ import pytest import pennylane as qml -from pennylane.capture.primitives import _get_abstract_operator jax = pytest.importorskip("jax") -pytestmark = pytest.mark.jax +from pennylane.capture import AbstractOperator # pylint: disable=wrong-import-position -AbstractOperator = _get_abstract_operator() +pytestmark = pytest.mark.jax @pytest.fixture(autouse=True) diff --git a/tests/capture/test_templates.py b/tests/capture/test_templates.py index 7648a4c1614..be54b2e8cb8 100644 --- a/tests/capture/test_templates.py +++ b/tests/capture/test_templates.py @@ -23,14 +23,12 @@ import pytest import pennylane as qml -from pennylane.capture.primitives import _get_abstract_operator jax = pytest.importorskip("jax") jnp = jax.numpy pytestmark = pytest.mark.jax -AbstractOperator = _get_abstract_operator() original_op_bind_code = qml.operation.Operator._primitive_bind_call.__code__