Skip to content

Commit

Permalink
minor capture fixes for from_plxpr
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro committed Jun 20, 2024
1 parent 14a0b63 commit 6055eee
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 13 deletions.
11 changes: 11 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,15 @@ def __getattr__(key):
from .primitives import _get_abstract_operator # pylint: disable=import-outside-toplevel

return _get_abstract_operator()

if key == "AbstractMeasurement":
from .primitives import _get_abstract_measurement # pylint: disable=import-outside-toplevel

return _get_abstract_measurement()

if key == "qnode_prim":
from .capture_qnode import _get_qnode_prim # pylint: disable=import-outside-toplevel

return _get_qnode_prim()

raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'")
14 changes: 10 additions & 4 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _get_qnode_prim():
qnode_prim.multiple_results = True

@qnode_prim.def_impl
def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr):
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr):
def qfunc(*inner_args):
return jax.core.eval_jaxpr(qfunc_jaxpr.jaxpr, qfunc_jaxpr.consts, *inner_args)

Expand All @@ -70,7 +70,7 @@ def qfunc(*inner_args):

# pylint: disable=unused-argument
@qnode_prim.def_abstract_eval
def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr):
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr):
mps = qfunc_jaxpr.out_avals
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

Expand Down Expand Up @@ -166,6 +166,12 @@ def f(x):
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_prim = _get_qnode_prim()

return qnode_prim.bind(
*args, shots=shots, device=qnode.device, qnode_kwargs=qnode_kwargs, qfunc_jaxpr=qfunc_jaxpr
res = qnode_prim.bind(
*args,
shots=shots,
qnode=qnode,
device=qnode.device,
qnode_kwargs=qnode_kwargs,
qfunc_jaxpr=qfunc_jaxpr,
)
return res[0] if len(res) == 1 else res
2 changes: 1 addition & 1 deletion pennylane/measurements/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def circuit(x):
[0, 0]])
"""
return SampleMP(obs=op, wires=wires)
return SampleMP(obs=op, wires=None if wires is None else qml.wires.Wires(wires))


class SampleMP(SampleMeasurement):
Expand Down
1 change: 1 addition & 0 deletions tests/capture/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def circuit(x):
assert jaxpr.out_avals[0] == jax.core.ShapedArray((), fdtype)

assert eqn0.params["device"] == dev
assert eqn0.params["qnode"] == circuit
assert eqn0.params["shots"] == qml.measurements.Shots(None)
expected_kwargs = {"diff_method": "best"}
expected_kwargs.update(circuit.execute_kwargs)
Expand Down
29 changes: 21 additions & 8 deletions tests/capture/test_measurements_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,33 +453,46 @@ def f(c1, c2):
@pytest.mark.parametrize("x64_mode", (True, False))
class TestSample:

@pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4)])
@pytest.mark.parametrize("wires, dim1_len", [([0, 1, 2], 3), ([], 4), (1, 1)])
def test_wires(self, wires, dim1_len, x64_mode):
"""Tests capturing samples on wires."""

initial_mode = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", x64_mode)

def f(*inner_wires):
return qml.sample(wires=inner_wires)
if isinstance(wires, list):

jaxpr = jax.make_jaxpr(f)(*wires)
def f(*inner_wires):
return qml.sample(wires=inner_wires)

jaxpr = jax.make_jaxpr(f)(*wires)
else:

def f(inner_wire):
return qml.sample(wires=inner_wire)

jaxpr = jax.make_jaxpr(f)(wires)

assert len(jaxpr.eqns) == 1

assert jaxpr.eqns[0].primitive == SampleMP._wires_primitive
assert [x.aval for x in jaxpr.eqns[0].invars] == jaxpr.in_avals
mp = jaxpr.eqns[0].outvars[0].aval
assert isinstance(mp, AbstractMeasurement)
assert mp.n_wires == len(wires)
assert mp.n_wires == len(wires) if isinstance(wires, list) else 1
assert mp._abstract_eval == SampleMP._abstract_eval

shapes = _get_shapes_for(
*jaxpr.out_avals, shots=qml.measurements.Shots(50), num_device_wires=4
)
assert shapes[0] == jax.core.ShapedArray(
(50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32
)
if isinstance(wires, list):
assert shapes[0] == jax.core.ShapedArray(
(50, dim1_len), jax.numpy.int64 if x64_mode else jax.numpy.int32
)
else:
assert shapes[0] == jax.core.ShapedArray(
(50,), jax.numpy.int64 if x64_mode else jax.numpy.int32
)

with pytest.raises(ValueError, match="finite shots are required"):
jaxpr.out_avals[0].abstract_eval(shots=None, num_device_wires=4)
Expand Down

0 comments on commit 6055eee

Please sign in to comment.