Skip to content

Commit

Permalink
Allow abstract Boolean measurement results as a parameter to qml.samp…
Browse files Browse the repository at this point in the history
…le (#5673)

This PR allows `qml.sample` (and implicitly other measurement processes)
to accept a Boolean tracer in the observable parameter, similar to what
is already done for the `MeasurementValue` value object. The `wires`
attribute will be empty in this case.

[sc-62096]
  • Loading branch information
dime10 committed May 24, 2024
1 parent 9c9f650 commit fbc2a39
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 1 deletion.
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@
* Sets up the framework for the development of an `assert_equal` function for testing operator comparison.
[(#5634)](https://github.com/PennyLaneAI/pennylane/pull/5634)

* `qml.sample` can now be used on Boolean values representing mid-circuit measurement results in
traced quantum functions. This feature is used with Catalyst to enable the pattern
`m = measure(0); qml.sample(m)`.
[(#5673)](https://github.com/PennyLaneAI/pennylane/pull/5673)

* PennyLane operators can now automatically be captured as instructions in JAXPR. See the experimental
`capture` module for more information.
[(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511)
Expand Down Expand Up @@ -202,6 +207,7 @@ Ahmed Darwish,
Isaac De Vlugt,
Pietropaolo Frisoni,
Emiliano Godinez,
David Ittah,
Soran Jahangiri,
Korbinian Kottmann,
Christina Lee,
Expand Down
6 changes: 5 additions & 1 deletion pennylane/measurements/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Optional, Sequence, Tuple, Union

import pennylane as qml
from pennylane.math.utils import is_abstract
from pennylane.operation import DecompositionUndefinedError, EigvalsUndefinedError, Operator
from pennylane.pytrees import register_pytree
from pennylane.typing import TensorLike
Expand Down Expand Up @@ -162,6 +163,9 @@ def __init__(
# Cast sequence of measurement values to list
self.mv = obs if getattr(obs, "name", None) == "MeasurementValue" else list(obs)
self.obs = None
elif is_abstract(obs): # Catalyst program with qml.sample(m, wires=i)
self.mv = obs
self.obs = None
else:
self.obs = obs
self.mv = None
Expand Down Expand Up @@ -306,7 +310,7 @@ def wires(self):
This is the union of all the Wires objects of the measurement.
"""
if self.mv is not None:
if self.mv is not None and not is_abstract(self.mv):
if isinstance(self.mv, list):
return qml.wires.Wires.all_wires([m.wires for m in self.mv])
return self.mv.wires
Expand Down
18 changes: 18 additions & 0 deletions tests/measurements/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,24 @@ def circuit(x):
)


@pytest.mark.jax
def test_sample_with_boolean_tracer():
"""Test that qml.sample can be used with Catalyst measurement values (Boolean tracer)."""
import jax

def fun(b):
mp = qml.sample(b)

assert mp.obs is None
assert isinstance(mp.mv, jax.interpreters.partial_eval.DynamicJaxprTracer)
assert mp.mv.dtype == bool
assert mp.mv.shape == ()
assert isinstance(mp.wires, qml.wires.Wires)
assert mp.wires == ()

jax.make_jaxpr(fun)(True)


@pytest.mark.jax
@pytest.mark.parametrize(
"obs",
Expand Down
21 changes: 21 additions & 0 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,3 +716,24 @@ def f(x):
CompileError, match="Pennylane does not support the VJP function without QJIT."
):
vjp(x, dy)


class TestCatalystSample:
"""Test qml.sample with Catalyst."""

@pytest.mark.xfail(reason="requires simultaneous catalyst pr")
def test_sample_measure(self):
"""Test that qml.sample can be used with catalyst.measure."""

dev = qml.device("lightning.qubit", wires=1, shots=1)

@qml.qjit
@qml.qnode(dev)
def circuit(x):
qml.RY(x, wires=0)
m = catalyst.measure(0)
qml.PauliX(0)
return qml.sample(m)

assert circuit(0.0) == 0
assert circuit(jnp.pi) == 1

0 comments on commit fbc2a39

Please sign in to comment.