Skip to content

Commit

Permalink
Merge branch 'master' into fermi-commute
Browse files Browse the repository at this point in the history
  • Loading branch information
willjmax committed Sep 18, 2024
2 parents 1aa1f56 + d7db6b4 commit 45842f2
Showing 1 changed file with 49 additions and 215 deletions.
264 changes: 49 additions & 215 deletions tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest

import pennylane as qml
from pennylane.devices.qubit.apply_operation import MidMeasureMP, apply_mid_measure
from pennylane.devices.qubit.apply_operation import MidMeasureMP
from pennylane.devices.qubit.simulate import combine_measurements_core, measurement_with_no_shots
from pennylane.transforms.dynamic_one_shot import fill_in_value

Expand Down Expand Up @@ -53,16 +53,9 @@ def test_measurement_with_no_shots():
assert all(np.isnan(probs).tolist())


def test_apply_mid_measure():
"""Test that apply_mid_measure raises if applied to a batched state."""
with pytest.raises(ValueError, match="MidMeasureMP cannot be applied to batched states."):
_ = apply_mid_measure(
MidMeasureMP(0), np.zeros((2, 2)), is_state_batched=True, mid_measurements={}
)


@pytest.mark.parametrize("obs", ["mid-meas", "pauli"])
@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"])
def test_all_invalid_shots_circuit(mcm_method):
def test_all_invalid_shots_circuit(obs, mcm_method):
"""Test that circuits in which all shots mismatch with post-selection conditions return the same answer as ``defer_measurements``."""
dev = get_device()
dev_shots = get_device(shots=10)
Expand All @@ -71,9 +64,13 @@ def circuit_op():
m = qml.measure(0, postselect=1)
qml.cond(m, qml.PauliX)(1)
return (
qml.expval(op=qml.PauliZ(1)),
qml.probs(op=qml.PauliY(0) @ qml.PauliZ(1)),
qml.var(op=qml.PauliZ(1)),
(
qml.expval(op=qml.PauliZ(1)),
qml.probs(op=qml.PauliY(0) @ qml.PauliZ(1)),
qml.var(op=qml.PauliZ(1)),
)
if obs == "pauli"
else (qml.expval(op=m), qml.probs(op=m), qml.var(op=m))
)

res1 = qml.QNode(circuit_op, dev, mcm_method="deferred")()
Expand All @@ -84,19 +81,6 @@ def circuit_op():
assert np.all(np.isnan(r1))
assert np.all(np.isnan(r2))

def circuit_mcm():
m = qml.measure(0, postselect=1)
qml.cond(m, qml.PauliX)(1)
return qml.expval(op=m), qml.probs(op=m), qml.var(op=m)

res1 = qml.QNode(circuit_mcm, dev, mcm_method="deferred")()
res2 = qml.QNode(circuit_mcm, dev_shots, mcm_method=mcm_method)()
for r1, r2 in zip(res1, res2):
if isinstance(r1, Sequence):
assert len(r1) == len(r2)
assert np.all(np.isnan(r1))
assert np.all(np.isnan(r2))


@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"])
def test_unsupported_measurement(mcm_method):
Expand All @@ -118,41 +102,6 @@ def func(x, y):
func(*params)


@pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"])
def test_tree_traversal_postselect_mode(postselect_mode):
"""Test that invalid shots are discarded if requested"""
shots = 100
dev = qml.device("default.qubit", shots=shots)

@qml.qnode(dev, mcm_method="tree-traversal", postselect_mode=postselect_mode)
def f(x):
qml.RX(x, 0)
_ = qml.measure(0, postselect=1)
return qml.sample(wires=[0, 1])

res = f(np.pi / 2)

if postselect_mode == "hw-like":
assert len(res) < shots
else:
assert len(res) == shots
assert np.all(res != np.iinfo(np.int32).min)


def test_deep_circuit():
"""Tests that DefaultQubit handles a circuit with more than 1000 mid-circuit measurements."""

dev = qml.device("default.qubit", shots=10)

def func(x):
for _ in range(600):
qml.RX(x, wires=0)
m0 = qml.measure(0)
return qml.expval(qml.PauliY(0)), qml.expval(m0)

_ = qml.QNode(func, dev, mcm_method="tree-traversal")(0.1234)


# pylint: disable=unused-argument
def obs_tape(x, y, z, reset=False, postselect=None):
qml.RX(x, 0)
Expand All @@ -174,69 +123,16 @@ def obs_tape(x, y, z, reset=False, postselect=None):

@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"])
@pytest.mark.parametrize("shots", [None, 5500, [5500, 5501]])
@pytest.mark.parametrize("postselect", [None, 0, 1])
@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var])
@pytest.mark.parametrize(
"meas_obj",
[qml.PauliZ(0), qml.PauliY(1), [0], [0, 1], [1, 0], "mcm", "composite_mcm", "mcm_list"],
"params",
[
[np.pi / 2.5, np.pi / 3, -np.pi / 3.5],
[[np.pi / 2.5, -np.pi / 3.5], [np.pi / 4.5, np.pi / 3.2], [np.pi, -np.pi / 0.5]],
],
)
def test_simple_dynamic_circuit(mcm_method, shots, measure_f, postselect, meas_obj):
"""Tests that DefaultQubit handles a simple dynamic circuit with the following measurements:
* qml.counts with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list
* qml.expval with obs (comp basis or not), MCM, f(MCM), MCM list
* qml.probs with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list
* qml.sample with obs (comp basis or not), single wire, multiple wires (ordered/unordered), MCM, f(MCM), MCM list
* qml.var with obs (comp basis or not), MCM, f(MCM), MCM list
The above combinations should work for finite shots, shot vectors and post-selecting of either the 0 or 1 branch.
"""

if (
isinstance(meas_obj, (qml.Z, qml.Y))
and measure_f == qml.var
and mcm_method == "tree-traversal"
and not qml.operation.active_new_opmath()
):
pytest.xfail(
"The tree-traversal method does not work with legacy opmath with "
"`qml.var` of pauli observables in the circuit."
)

if mcm_method == "one-shot" and shots is None:
pytest.skip("`mcm_method='one-shot'` is incompatible with analytic mode (`shots=None`)")

if measure_f in (qml.expval, qml.var) and (
isinstance(meas_obj, list) or meas_obj == "mcm_list"
):
pytest.skip("Can't use wires/mcm lists with var or expval")

if measure_f in (qml.counts, qml.sample) and shots is None:
pytest.skip("Can't measure counts/sample in analytic mode (`shots=None`)")

dev = get_device(shots=shots)
params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5]

def func(x, y, z):
m0, m1 = obs_tape(x, y, z, postselect=postselect)
mid_measure = (
m0 if meas_obj == "mcm" else (0.5 * m0 if meas_obj == "composite_mcm" else [m0, m1])
)
measurement_key = "wires" if isinstance(meas_obj, list) else "op"
measurement_value = mid_measure if isinstance(meas_obj, str) else meas_obj
return measure_f(**{measurement_key: measurement_value})

results0 = qml.QNode(func, dev, mcm_method=mcm_method)(*params)
results1 = qml.QNode(func, dev, mcm_method="deferred")(*params)

mcm_utils.validate_measurements(measure_f, shots, results1, results0)


@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"])
@pytest.mark.parametrize("shots", [None, 5000])
@pytest.mark.parametrize("postselect", [None, 0, 1])
@pytest.mark.parametrize("reset", [False, True])
def test_multiple_measurements_and_reset(mcm_method, shots, postselect, reset):
def test_multiple_measurements_and_reset(mcm_method, shots, params, postselect, reset):
"""Tests that DefaultQubit handles a circuit with a single mid-circuit measurement with reset
and a conditional gate. Multiple measurements of the mid-circuit measurement value are
performed. This function also tests `reset` parametrizing over the parameter."""
Expand All @@ -250,8 +146,11 @@ def test_multiple_measurements_and_reset(mcm_method, shots, postselect, reset):
if mcm_method == "one-shot" and shots is None:
pytest.skip("`mcm_method='one-shot'` is incompatible with analytic mode (`shots=None`)")

batch_size = len(params[0]) if isinstance(params[0], list) else None
if batch_size is not None and shots is not None and postselect is not None:
pytest.skip("Postselection with samples doesn't work with broadcasting")

dev = get_device(shots=shots)
params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5]
obs = qml.PauliY(1)
state = qml.math.zeros((4,))
state[0] = 1.0
Expand Down Expand Up @@ -286,27 +185,34 @@ def func(x, y, z):
if shots is None
else [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]
)
for measure_f, r1, r0 in zip(measurements, results1, results0):
mcm_utils.validate_measurements(measure_f, shots, r1, r0)

if not isinstance(shots, list):
shots, results0, results1 = [shots], [results0], [results1]

for shot, res1, res0 in zip(shots, results1, results0):
for measure_f, r1, r0 in zip(measurements, res1, res0):
if shots is None and measure_f in [qml.expval, qml.probs] and batch_size is not None:
r0 = qml.math.squeeze(r0)
mcm_utils.validate_measurements(measure_f, shot, r1, r0, batch_size=batch_size)


@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"])
@pytest.mark.parametrize("shots", [None, 3000])
@pytest.mark.parametrize("shots", [None, 5000, [5000, 5001]])
@pytest.mark.parametrize(
"mcm_f",
"mcm_name, mcm_func",
[
lambda x: x * -1,
lambda x: x * 1,
lambda x: x * 2,
lambda x: 1 - x,
lambda x: x + 1,
lambda x: x & 3,
"mix",
"list",
("single", lambda x: x * -1),
("single", lambda x: x * 2),
("single", lambda x: 1 - x),
("single", lambda x: x & 3),
("mix", lambda x, y: x == y),
("mix", lambda x, y: 4 * x + 2 * y),
("all", lambda x, y, z: [x, y, z]),
("all", lambda x, y, z: (x - 2 * y) * z + 7),
],
)
@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var])
def test_composite_mcms(mcm_method, shots, mcm_f, measure_f):
def test_composite_mcms(mcm_method, shots, mcm_name, mcm_func, measure_f):
"""Tests that DefaultQubit handles a circuit with a composite mid-circuit measurement and a
conditional gate. A single measurement of a composite mid-circuit measurement is performed
at the end."""
Expand All @@ -317,73 +223,39 @@ def test_composite_mcms(mcm_method, shots, mcm_f, measure_f):
if measure_f in (qml.counts, qml.sample) and shots is None:
pytest.skip("Can't measure counts/sample in analytic mode (`shots=None`)")

if measure_f in (qml.expval, qml.var) and (mcm_f in ("list", "mix")):
if measure_f in (qml.expval, qml.var) and mcm_name in ["mix", "all"]:
pytest.skip(
"expval/var does not support measuring sequences of measurements or observables."
)

if measure_f == qml.probs and mcm_f == "mix":
if measure_f in (qml.probs,) and mcm_name in ["mix", "all"]:
pytest.skip(
"Cannot use qml.probs() when measuring multiple mid-circuit measurements collected using arithmetic operators."
)

dev = get_device(shots=shots)
param = np.pi / 3
param = qml.numpy.array([np.pi / 3, np.pi / 6])

def func(x):
def func(x, y):
qml.RX(x, 0)
m0 = qml.measure(0)
qml.RX(0.5 * x, 1)
m1 = qml.measure(1)
qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0)
m2 = qml.measure(0)
obs = (
(m0 - 2 * m1) * m2 + 7
if mcm_f == "mix"
else ([m0, m1, m2] if mcm_f == "list" else mcm_f(m2))
mcm_func(m2)
if mcm_name == "single"
else (mcm_func(m0, m1) if mcm_name == "mix" else mcm_func(m0, m1, m2))
)
return measure_f(op=obs)

results0 = qml.QNode(func, dev, mcm_method=mcm_method)(param)
results1 = qml.QNode(func, dev, mcm_method="deferred")(param)
results0 = qml.QNode(func, dev, mcm_method=mcm_method)(*param)
results1 = qml.QNode(func, dev, mcm_method="deferred")(*param)

mcm_utils.validate_measurements(measure_f, shots, results1, results0)


@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"])
@pytest.mark.parametrize(
"mcm_f",
[
lambda x, y: x + y,
lambda x, y: x - 7 * y,
lambda x, y: x & y,
lambda x, y: x == y,
lambda x, y: 4.0 * x + 2.0 * y,
],
)
def test_counts_return_type(mcm_method, mcm_f):
"""Tests that DefaultQubit returns the same keys for ``qml.counts`` measurements with ``dynamic_one_shot`` and ``defer_measurements``."""

shots = 500

dev = get_device(shots=shots)
param = np.pi / 3

def func(x):
qml.RX(x, 0)
m0 = qml.measure(0)
qml.RX(0.5 * x, 1)
m1 = qml.measure(1)
qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0)
return qml.counts(op=mcm_f(m0, m1))

results0 = qml.QNode(func, dev, mcm_method=mcm_method)(param)
results1 = qml.QNode(func, dev, mcm_method="deferred")(param)

for r1, r0 in zip(results1.keys(), results0.keys()):
assert r1 == r0


@pytest.mark.parametrize("shots", [5000])
@pytest.mark.parametrize("postselect", [None, 0, 1])
@pytest.mark.parametrize("reset", [False, True])
Expand Down Expand Up @@ -421,44 +293,6 @@ def func(x, y):
assert np.allclose(grad1, grad2, atol=0.01, rtol=0.3)


@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"])
@pytest.mark.parametrize("shots", [None, 5500, [5500, 5501]])
@pytest.mark.parametrize("postselect", [None, 0])
@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample])
def test_broadcasting_qnode(mcm_method, shots, postselect, measure_f):
"""Test that executing qnodes with broadcasting works as expected"""

if mcm_method == "one-shot" and shots is None:
pytest.skip("`mcm_method='one-shot'` is incompatible with analytic mode (`shots=None`)")

if measure_f in (qml.counts, qml.sample) and shots is None:
pytest.skip("Can't measure counts/sample in analytic mode (`shots=None`)")

if measure_f is qml.sample and postselect is not None:
pytest.skip("Postselection with samples doesn't work with broadcasting")

dev = get_device(shots=shots)
param = [[np.pi / 3, np.pi / 4], [np.pi / 6, 2 * np.pi / 3]]
obs = qml.PauliZ(0) @ qml.PauliZ(1)

def func(x, y):
obs_tape(x, y, None, postselect=postselect)
return measure_f(op=obs)

results0 = qml.QNode(func, dev, mcm_method=mcm_method)(*param)
results1 = qml.QNode(func, dev, mcm_method="deferred")(*param)

mcm_utils.validate_measurements(measure_f, shots, results1, results0, batch_size=2)

if measure_f is qml.sample and postselect is None:
for i in range(2): # batch_size
if isinstance(shots, list):
for s, r1, r2 in zip(shots, results1, results0):
assert len(r1[i]) == len(r2[i]) == s
else:
assert len(results1[i]) == len(results0[i]) == shots


@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"])
def test_sample_with_broadcasting_and_postselection_error(mcm_method):
"""Test that an error is raised if returning qml.sample if postselecting with broadcasting"""
Expand Down

0 comments on commit 45842f2

Please sign in to comment.