From 6055eee1fce5fc26daaba8d0ef798a6172ebcf35 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 20 Jun 2024 15:31:20 -0400 Subject: [PATCH 1/6] minor capture fixes for from_plxpr --- pennylane/capture/__init__.py | 11 ++++++++ pennylane/capture/capture_qnode.py | 14 ++++++++--- pennylane/measurements/sample.py | 2 +- tests/capture/test_capture_qnode.py | 1 + tests/capture/test_measurements_capture.py | 29 ++++++++++++++++------ 5 files changed, 44 insertions(+), 13 deletions(-) diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index c740d85700e..63b4aaa4261 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -141,4 +141,15 @@ def __getattr__(key): from .primitives import _get_abstract_operator # pylint: disable=import-outside-toplevel return _get_abstract_operator() + + if key == "AbstractMeasurement": + from .primitives import _get_abstract_measurement # pylint: disable=import-outside-toplevel + + return _get_abstract_measurement() + + if key == "qnode_prim": + from .capture_qnode import _get_qnode_prim # pylint: disable=import-outside-toplevel + + return _get_qnode_prim() + raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'") diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py index 6f305f88507..5852a57a0f9 100644 --- a/pennylane/capture/capture_qnode.py +++ b/pennylane/capture/capture_qnode.py @@ -61,7 +61,7 @@ def _get_qnode_prim(): qnode_prim.multiple_results = True @qnode_prim.def_impl - def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr): + def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr): def qfunc(*inner_args): return jax.core.eval_jaxpr(qfunc_jaxpr.jaxpr, qfunc_jaxpr.consts, *inner_args) @@ -70,7 +70,7 @@ def qfunc(*inner_args): # pylint: disable=unused-argument @qnode_prim.def_abstract_eval - def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr): + def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr): mps = qfunc_jaxpr.out_avals return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires)) @@ -166,6 +166,12 @@ def f(x): qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config} qnode_prim = _get_qnode_prim() - return qnode_prim.bind( - *args, shots=shots, device=qnode.device, qnode_kwargs=qnode_kwargs, qfunc_jaxpr=qfunc_jaxpr + res = qnode_prim.bind( + *args, + shots=shots, + qnode=qnode, + device=qnode.device, + qnode_kwargs=qnode_kwargs, + qfunc_jaxpr=qfunc_jaxpr, ) + return res[0] if len(res) == 1 else res diff --git a/pennylane/measurements/sample.py b/pennylane/measurements/sample.py index da10c55e61c..a8333266c8a 100644 --- a/pennylane/measurements/sample.py +++ b/pennylane/measurements/sample.py @@ -134,7 +134,7 @@ def circuit(x): [0, 0]]) """ - return SampleMP(obs=op, wires=wires) + return SampleMP(obs=op, wires=None if wires is None else qml.wires.Wires(wires)) class SampleMP(SampleMeasurement): diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index ca8f63ffed2..56f9f4739f0 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -119,6 +119,7 @@ def circuit(x): assert jaxpr.out_avals[0] == jax.core.ShapedArray((), fdtype) assert eqn0.params["device"] == dev + assert eqn0.params["qnode"] == circuit assert eqn0.params["shots"] == qml.measurements.Shots(None) expected_kwargs = {"diff_method": "best"} expected_kwargs.update(circuit.execute_kwargs) diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py index b516e0b30b9..383c8a97bb5 100644 --- a/tests/capture/test_measurements_capture.py +++ b/tests/capture/test_measurements_capture.py @@ -453,17 +453,25 @@ def f(c1, c2): @pytest.mark.parametrize("x64_mode", (True, False)) class TestSample: - @pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4)]) + @pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4), (1, 1)]) def test_wires(self, wires, dim1_len, x64_mode): """Tests capturing samples on wires.""" initial_mode = jax.config.jax_enable_x64 jax.config.update("jax_enable_x64", x64_mode) - def f(*inner_wires): - return qml.sample(wires=inner_wires) + if isinstance(wires, list): - jaxpr = jax.make_jaxpr(f)(*wires) + def f(*inner_wires): + return qml.sample(wires=inner_wires) + + jaxpr = jax.make_jaxpr(f)(*wires) + else: + + def f(inner_wire): + return qml.sample(wires=inner_wire) + + jaxpr = jax.make_jaxpr(f)(wires) assert len(jaxpr.eqns) == 1 @@ -471,15 +479,20 @@ def f(*inner_wires): assert [x.aval for x in jaxpr.eqns[0].invars] == jaxpr.in_avals mp = jaxpr.eqns[0].outvars[0].aval assert isinstance(mp, AbstractMeasurement) - assert mp.n_wires == len(wires) + assert mp.n_wires == len(wires) if isinstance(wires, list) else 1 assert mp._abstract_eval == SampleMP._abstract_eval shapes = _get_shapes_for( *jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=4 ) - assert shapes[0] == jax.core.ShapedArray( - (50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32 - ) + if isinstance(wires, list): + assert shapes[0] == jax.core.ShapedArray( + (50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) + else: + assert shapes[0] == jax.core.ShapedArray( + (50,), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) with pytest.raises(ValueError, match="finite shots are required"): jaxpr.out_avals[0].abstract_eval(shots=None, num_device_wires=4) From 98dee71ccf2e422abb1bc5e9783a90fdac4a5f5b Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 20 Jun 2024 15:51:50 -0400 Subject: [PATCH 2/6] changelog --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index eb52d55607d..5b292f9de9e 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -277,6 +277,7 @@ [(#5708)](https://github.com/PennyLaneAI/pennylane/pull/5708) [(#5523)](https://github.com/PennyLaneAI/pennylane/pull/5523) [(#5686)](https://github.com/PennyLaneAI/pennylane/pull/5686) + [(#5889)](https://github.com/PennyLaneAI/pennylane/pull/5889) * The `decompose` transform has an `error` kwarg to specify the type of error that should be raised, allowing error types to be more consistent with the context the `decompose` function is used in. From 7d811a6a532f0d575f1ae5a31cb4d915177fc394 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Thu, 20 Jun 2024 15:52:05 -0400 Subject: [PATCH 3/6] Apply suggestions from code review Co-authored-by: David Wierichs --- tests/capture/test_measurements_capture.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py index 383c8a97bb5..a653a3146f9 100644 --- a/tests/capture/test_measurements_capture.py +++ b/tests/capture/test_measurements_capture.py @@ -485,14 +485,11 @@ def f(inner_wire): shapes = _get_shapes_for( *jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=4 ) - if isinstance(wires, list): - assert shapes[0] == jax.core.ShapedArray( - (50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32 - ) - else: - assert shapes[0] == jax.core.ShapedArray( - (50,), jax.numpy.int64 if x64_mode else jax.numpy.int32 - ) + assert len(shapes) == 1 + shape = (50, dim1_len) if isinstance(wires, list) else (50,) + assert shapes[0] == jax.core.ShapedArray( + shape, jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) with pytest.raises(ValueError, match="finite shots are required"): jaxpr.out_avals[0].abstract_eval(shots=None, num_device_wires=4) From b585f4a6c1e9a0e3bc3efe24ec7bf09822120372 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 20 Jun 2024 18:12:34 -0400 Subject: [PATCH 4/6] fix test coverage --- pennylane/capture/__init__.py | 25 ++++++++++++++++++++++ tests/capture/test_capture_qnode.py | 4 +--- tests/capture/test_measurements_capture.py | 4 +--- tests/capture/test_operators.py | 4 +--- tests/capture/test_templates.py | 2 -- 5 files changed, 28 insertions(+), 11 deletions(-) diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 63b4aaa4261..75e82da4053 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -135,6 +135,13 @@ def _(*args, **kwargs): ) from .capture_qnode import qnode_call +# by defining this here, we avoid +# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module) +# on use of from capture import AbstractOperator +AbstractOperator: type +AbstractMeasurement: type +qnode_prim: "jax.core.Primitive" + def __getattr__(key): if key == "AbstractOperator": @@ -153,3 +160,21 @@ def __getattr__(key): return _get_qnode_prim() raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'") + + +# pylint: disable=undefined-all-variable +__all__ = ( + "disable", + "enable", + "enabled", + "CaptureMeta", + "ABCCaptureMeta", + "create_operator_primitive", + "create_measurement_obs_primitive", + "create_measurement_wires_primitive", + "create_measurement_mcm_primitive", + "qnode_call", + "AbstractOperator", + "AbstractMeasurement", + "qnode_prim", +) diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 56f9f4739f0..553c4485b73 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -21,9 +21,7 @@ import pytest import pennylane as qml -from pennylane.capture.capture_qnode import _get_qnode_prim - -qnode_prim = _get_qnode_prim() +from pennylane.capture import qnode_prim pytestmark = pytest.mark.jax diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py index a653a3146f9..eb0ddc5bf6f 100644 --- a/tests/capture/test_measurements_capture.py +++ b/tests/capture/test_measurements_capture.py @@ -20,7 +20,7 @@ import pytest import pennylane as qml -from pennylane.capture.primitives import _get_abstract_measurement +from pennylane.capture import AbstractMeasurement from pennylane.measurements import ( ClassicalShadowMP, DensityMatrixMP, @@ -40,8 +40,6 @@ pytestmark = pytest.mark.jax -AbstractMeasurement = _get_abstract_measurement() - @pytest.fixture(autouse=True) def enable_disable_plxpr(): diff --git a/tests/capture/test_operators.py b/tests/capture/test_operators.py index 6388a3b6ea4..3caff3e42fb 100644 --- a/tests/capture/test_operators.py +++ b/tests/capture/test_operators.py @@ -18,14 +18,12 @@ import pytest import pennylane as qml -from pennylane.capture.primitives import _get_abstract_operator +from pennylane.capture import AbstractOperator jax = pytest.importorskip("jax") pytestmark = pytest.mark.jax -AbstractOperator = _get_abstract_operator() - @pytest.fixture(autouse=True) def enable_disable_plxpr(): diff --git a/tests/capture/test_templates.py b/tests/capture/test_templates.py index 7648a4c1614..be54b2e8cb8 100644 --- a/tests/capture/test_templates.py +++ b/tests/capture/test_templates.py @@ -23,14 +23,12 @@ import pytest import pennylane as qml -from pennylane.capture.primitives import _get_abstract_operator jax = pytest.importorskip("jax") jnp = jax.numpy pytestmark = pytest.mark.jax -AbstractOperator = _get_abstract_operator() original_op_bind_code = qml.operation.Operator._primitive_bind_call.__code__ From e733491d1456396cc60081910090a930b1cc2026 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 20 Jun 2024 18:15:01 -0400 Subject: [PATCH 5/6] fix test coverage --- pennylane/capture/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 75e82da4053..98637c2d425 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -162,7 +162,6 @@ def __getattr__(key): raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'") -# pylint: disable=undefined-all-variable __all__ = ( "disable", "enable", From 3f4c8e0af23488fee7379f7dc42365e594bc9642 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 21 Jun 2024 09:46:26 -0400 Subject: [PATCH 6/6] change import order --- tests/capture/test_measurements_capture.py | 3 ++- tests/capture/test_operators.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py index eb0ddc5bf6f..a3bce757266 100644 --- a/tests/capture/test_measurements_capture.py +++ b/tests/capture/test_measurements_capture.py @@ -20,7 +20,6 @@ import pytest import pennylane as qml -from pennylane.capture import AbstractMeasurement from pennylane.measurements import ( ClassicalShadowMP, DensityMatrixMP, @@ -38,6 +37,8 @@ jax = pytest.importorskip("jax") +from pennylane.capture import AbstractMeasurement # pylint: disable=wrong-import-position + pytestmark = pytest.mark.jax diff --git a/tests/capture/test_operators.py b/tests/capture/test_operators.py index 3caff3e42fb..b5142dbb5e9 100644 --- a/tests/capture/test_operators.py +++ b/tests/capture/test_operators.py @@ -18,10 +18,11 @@ import pytest import pennylane as qml -from pennylane.capture import AbstractOperator jax = pytest.importorskip("jax") +from pennylane.capture import AbstractOperator # pylint: disable=wrong-import-position + pytestmark = pytest.mark.jax