From 6055eee1fce5fc26daaba8d0ef798a6172ebcf35 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 20 Jun 2024 15:31:20 -0400 Subject: [PATCH] minor capture fixes for from_plxpr --- pennylane/capture/__init__.py | 11 ++++++++ pennylane/capture/capture_qnode.py | 14 ++++++++--- pennylane/measurements/sample.py | 2 +- tests/capture/test_capture_qnode.py | 1 + tests/capture/test_measurements_capture.py | 29 ++++++++++++++++------ 5 files changed, 44 insertions(+), 13 deletions(-) diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index c740d85700e..63b4aaa4261 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -141,4 +141,15 @@ def __getattr__(key): from .primitives import _get_abstract_operator # pylint: disable=import-outside-toplevel return _get_abstract_operator() + + if key == "AbstractMeasurement": + from .primitives import _get_abstract_measurement # pylint: disable=import-outside-toplevel + + return _get_abstract_measurement() + + if key == "qnode_prim": + from .capture_qnode import _get_qnode_prim # pylint: disable=import-outside-toplevel + + return _get_qnode_prim() + raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'") diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py index 6f305f88507..5852a57a0f9 100644 --- a/pennylane/capture/capture_qnode.py +++ b/pennylane/capture/capture_qnode.py @@ -61,7 +61,7 @@ def _get_qnode_prim(): qnode_prim.multiple_results = True @qnode_prim.def_impl - def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr): + def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr): def qfunc(*inner_args): return jax.core.eval_jaxpr(qfunc_jaxpr.jaxpr, qfunc_jaxpr.consts, *inner_args) @@ -70,7 +70,7 @@ def qfunc(*inner_args): # pylint: disable=unused-argument @qnode_prim.def_abstract_eval - def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr): + def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr): mps = qfunc_jaxpr.out_avals return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires)) @@ -166,6 +166,12 @@ def f(x): qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config} qnode_prim = _get_qnode_prim() - return qnode_prim.bind( - *args, shots=shots, device=qnode.device, qnode_kwargs=qnode_kwargs, qfunc_jaxpr=qfunc_jaxpr + res = qnode_prim.bind( + *args, + shots=shots, + qnode=qnode, + device=qnode.device, + qnode_kwargs=qnode_kwargs, + qfunc_jaxpr=qfunc_jaxpr, ) + return res[0] if len(res) == 1 else res diff --git a/pennylane/measurements/sample.py b/pennylane/measurements/sample.py index da10c55e61c..a8333266c8a 100644 --- a/pennylane/measurements/sample.py +++ b/pennylane/measurements/sample.py @@ -134,7 +134,7 @@ def circuit(x): [0, 0]]) """ - return SampleMP(obs=op, wires=wires) + return SampleMP(obs=op, wires=None if wires is None else qml.wires.Wires(wires)) class SampleMP(SampleMeasurement): diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index ca8f63ffed2..56f9f4739f0 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -119,6 +119,7 @@ def circuit(x): assert jaxpr.out_avals[0] == jax.core.ShapedArray((), fdtype) assert eqn0.params["device"] == dev + assert eqn0.params["qnode"] == circuit assert eqn0.params["shots"] == qml.measurements.Shots(None) expected_kwargs = {"diff_method": "best"} expected_kwargs.update(circuit.execute_kwargs) diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py index b516e0b30b9..383c8a97bb5 100644 --- a/tests/capture/test_measurements_capture.py +++ b/tests/capture/test_measurements_capture.py @@ -453,17 +453,25 @@ def f(c1, c2): @pytest.mark.parametrize("x64_mode", (True, False)) class TestSample: - @pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4)]) + @pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4), (1, 1)]) def test_wires(self, wires, dim1_len, x64_mode): """Tests capturing samples on wires.""" initial_mode = jax.config.jax_enable_x64 jax.config.update("jax_enable_x64", x64_mode) - def f(*inner_wires): - return qml.sample(wires=inner_wires) + if isinstance(wires, list): - jaxpr = jax.make_jaxpr(f)(*wires) + def f(*inner_wires): + return qml.sample(wires=inner_wires) + + jaxpr = jax.make_jaxpr(f)(*wires) + else: + + def f(inner_wire): + return qml.sample(wires=inner_wire) + + jaxpr = jax.make_jaxpr(f)(wires) assert len(jaxpr.eqns) == 1 @@ -471,15 +479,20 @@ def f(*inner_wires): assert [x.aval for x in jaxpr.eqns[0].invars] == jaxpr.in_avals mp = jaxpr.eqns[0].outvars[0].aval assert isinstance(mp, AbstractMeasurement) - assert mp.n_wires == len(wires) + assert mp.n_wires == len(wires) if isinstance(wires, list) else 1 assert mp._abstract_eval == SampleMP._abstract_eval shapes = _get_shapes_for( *jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=4 ) - assert shapes[0] == jax.core.ShapedArray( - (50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32 - ) + if isinstance(wires, list): + assert shapes[0] == jax.core.ShapedArray( + (50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) + else: + assert shapes[0] == jax.core.ShapedArray( + (50,), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) with pytest.raises(ValueError, match="finite shots are required"): jaxpr.out_avals[0].abstract_eval(shots=None, num_device_wires=4)