Skip to content

Commit

Permalink
Make process_samples on SampleMP jit-compatible with Tracer indices (#…
Browse files Browse the repository at this point in the history
…6211)

**Context:**
There is a use case in catalyst (currently a WIP, sampling an observable
using `measurements_from_samples` to diagonalize everything) that
results in `indices` in the `process_samples` function being an abstract
`Tracer`, and then things break that don't need to.

**Description of the Change:**
Update the jax `take` dispatch in `qml.math` to cast indices to a jax
array instead of a vanilla numpy array
  • Loading branch information
lillian542 committed Sep 6, 2024
1 parent 713a33d commit 40195fd
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 73 deletions.
16 changes: 11 additions & 5 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
`from pennylane.capture.primitives import *`.
[(#6129)](https://github.com/PennyLaneAI/pennylane/pull/6129)

* The `SampleMP.process_samples` method is updated to support using JAX tracers
for samples, allowing compatiblity with Catalyst workflows.
[(#6211)](https://github.com/PennyLaneAI/pennylane/pull/6211)

* Improve `qml.Qubitization` decomposition.
[(#6182)](https://github.com/PennyLaneAI/pennylane/pull/6182)

* The `__repr__` methods for `FermiWord` and `FermiSentence` now returns a
unique representation of the object.
[(#6167)](https://github.com/PennyLaneAI/pennylane/pull/6167)


<h3>Breaking changes 💔</h3>

* Remove support for Python 3.9.
Expand Down Expand Up @@ -61,8 +66,9 @@

This release contains contributions from (in alphabetical order):

Guillermo Alonso
Utkarsh Azad
Christina Lee
William Maxwell
Lee J. O'Riordan
Guillermo Alonso,
Utkarsh Azad,
Lillian M. A. Frederiksen,
Christina Lee,
William Maxwell,
Lee J. O'Riordan,
2 changes: 1 addition & 1 deletion pennylane/math/single_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def _to_numpy_jax(x):
"jax",
"take",
lambda x, indices, axis=None, **kwargs: _i("jax").numpy.take(
x, np.array(indices), axis=axis, **kwargs
x, _i("jax").numpy.asarray(indices), axis=axis, **kwargs
),
)
ar.register_function("jax", "coerce", lambda x: x)
Expand Down
5 changes: 4 additions & 1 deletion pennylane/measurements/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,10 @@ def process_samples(
indices = qml.math.array(indices) # Add np.array here for Jax support.
# This also covers statistics for mid-circuit measurements manipulated using
# arithmetic operators
samples = eigvals[indices]
if qml.math.is_abstract(indices):
samples = qml.math.take(eigvals, indices, like=indices)
else:
samples = eigvals[indices]

return samples if bin_size is None else samples.reshape((bin_size, -1))

Expand Down
135 changes: 69 additions & 66 deletions tests/measurements/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,14 +422,6 @@ class DummyOp(Operator): # pylint: disable=too-few-public-methods
with pytest.raises(EigvalsUndefinedError, match="Cannot compute samples of"):
qml.sample(op=DummyOp(0)).process_samples(samples=np.array([[1, 0]]), wire_order=[0])

def test_process_sample_shot_range(self):
"""Test process_samples with a shot range."""
mp = qml.sample(wires=0)

samples = np.zeros((10, 2))
out = mp.process_samples(samples, wire_order=qml.wires.Wires((0, 1)), shot_range=(0, 5))
assert qml.math.allclose(out, np.zeros((5,)))

def test_sample_allowed_with_parameter_shift(self):
"""Test that qml.sample doesn't raise an error with parameter-shift and autograd."""
dev = qml.device("default.qubit", shots=10)
Expand Down Expand Up @@ -461,76 +453,87 @@ def circuit(angle):


@pytest.mark.jax
@pytest.mark.parametrize("samples", (1, 10))
def test_jitting_with_sampling_on_subset_of_wires(samples):
"""Test case covering bug in Issue #3904. Sampling should be jit-able
when sampling occurs on a subset of wires. The bug was occuring due an improperly
set shape method."""
import jax
class TestJAXCompatibility:

jax.config.update("jax_enable_x64", True)
@pytest.mark.parametrize("samples", (1, 10))
def test_jitting_with_sampling_on_subset_of_wires(self, samples):
"""Test case covering bug in Issue #3904. Sampling should be jit-able
when sampling occurs on a subset of wires. The bug was occuring due an improperly
set shape method."""
import jax

dev = qml.device("default.qubit", wires=3, shots=samples)
jax.config.update("jax_enable_x64", True)

@qml.qnode(dev, interface="jax")
def circuit(x):
qml.RX(x, wires=0)
return qml.sample(wires=(0, 1))
dev = qml.device("default.qubit", wires=3, shots=samples)

results = jax.jit(circuit)(jax.numpy.array(0.123, dtype=jax.numpy.float64))
@qml.qnode(dev, interface="jax")
def circuit(x):
qml.RX(x, wires=0)
return qml.sample(wires=(0, 1))

expected = (2,) if samples == 1 else (samples, 2)
assert results.shape == expected
assert circuit._qfunc_output.shape(samples, 3) == (samples, 2) if samples != 1 else (2,)
results = jax.jit(circuit)(jax.numpy.array(0.123, dtype=jax.numpy.float64))

expected = (2,) if samples == 1 else (samples, 2)
assert results.shape == expected
assert circuit._qfunc_output.shape(samples, 3) == (samples, 2) if samples != 1 else (2,)

@pytest.mark.jax
def test_sample_with_boolean_tracer():
"""Test that qml.sample can be used with Catalyst measurement values (Boolean tracer)."""
import jax
def test_sample_with_boolean_tracer(self):
"""Test that qml.sample can be used with Catalyst measurement values (Boolean tracer)."""
import jax

def fun(b):
mp = qml.sample(b)
def fun(b):
mp = qml.sample(b)

assert mp.obs is None
assert isinstance(mp.mv, jax.interpreters.partial_eval.DynamicJaxprTracer)
assert mp.mv.dtype == bool
assert mp.mv.shape == ()
assert isinstance(mp.wires, qml.wires.Wires)
assert mp.wires == ()
assert mp.obs is None
assert isinstance(mp.mv, jax.interpreters.partial_eval.DynamicJaxprTracer)
assert mp.mv.dtype == bool
assert mp.mv.shape == ()
assert isinstance(mp.wires, qml.wires.Wires)
assert mp.wires == ()

jax.make_jaxpr(fun)(True)
jax.make_jaxpr(fun)(True)

@pytest.mark.parametrize(
"obs",
[
# Single observables
(qml.PauliX(0)),
(qml.PauliY(0)),
(qml.PauliZ(0)),
(qml.Hadamard(0)),
(qml.Identity(0)),
],
)
def test_jitting_with_sampling_on_different_observables(self, obs):
"""Test that jitting works when sampling observables (using their eigvals) rather than returning raw samples"""
import jax

@pytest.mark.jax
@pytest.mark.parametrize(
"obs",
[
# Single observables
(qml.PauliX(0)),
(qml.PauliY(0)),
(qml.PauliZ(0)),
(qml.Hadamard(0)),
(qml.Identity(0)),
],
)
def test_jitting_with_sampling_on_different_observables(obs):
"""Test that jitting works when sampling observables (using their eigvals) rather than returning raw samples"""
import jax

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", wires=5, shots=100)

@qml.qnode(dev, interface="jax")
def circuit(x):
qml.RX(x, wires=0)
return qml.sample(obs)

results = jax.jit(circuit)(jax.numpy.array(0.123, dtype=jax.numpy.float64))

assert results.dtype == jax.numpy.float64
assert np.all([r in [1, -1] for r in results])
jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", wires=5, shots=100)

@qml.qnode(dev, interface="jax")
def circuit(x):
qml.RX(x, wires=0)
return qml.sample(obs)

results = jax.jit(circuit)(jax.numpy.array(0.123, dtype=jax.numpy.float64))

assert results.dtype == jax.numpy.float64
assert np.all([r in [1, -1] for r in results])

def test_process_samples_with_jax_tracer(self):
"""Test that qml.sample can be used when samples is a JAX Tracer"""

import jax

def f(samples):
return qml.sample(op=2 * qml.X(0)).process_samples(
samples, wire_order=qml.wires.Wires((0, 1))
)

samples = jax.numpy.zeros((10, 2), dtype=int)
jax.jit(f)(samples)


class TestSampleProcessCounts:
Expand Down

0 comments on commit 40195fd

Please sign in to comment.