Skip to content

Commit

Permalink
Support broadcasting in simulate and preprocess (#4244)
Browse files Browse the repository at this point in the history
* Support broadcasting in state measurements

* docs for is_state_batched

* Support broadcasting in sample measurements

* Support broadcastingin simulate and preprocess

---------

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
  • Loading branch information
eddddddy and timmysilv committed Jun 16, 2023
1 parent bc31148 commit 9256e47
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 13 deletions.
15 changes: 10 additions & 5 deletions pennylane/devices/qubit/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dataclasses import replace
from typing import Generator, Callable, Tuple, Union
import warnings
from functools import partial

import pennylane as qml

Expand Down Expand Up @@ -241,7 +242,7 @@ def expand_fn(circuit: qml.tape.QuantumScript) -> qml.tape.QuantumScript:


def batch_transform(
circuit: qml.tape.QuantumScript,
circuit: qml.tape.QuantumScript, execution_config: ExecutionConfig = DefaultExecutionConfig
) -> Tuple[Tuple[qml.tape.QuantumScript], PostprocessingFn]:
"""Apply a differentiable batch transform for preprocessing a circuit
prior to execution.
Expand All @@ -258,15 +259,18 @@ def batch_transform(
Args:
circuit (.QuantumTape): the circuit to preprocess
execution_config (.ExecutionConfig): execution configuration with configurable
options for the execution.
Returns:
tuple[Sequence[.QuantumTape], callable]: Returns a tuple containing
the sequence of circuits to be executed, and a post-processing function
to be applied to the list of evaluated circuit results.
"""
# Check whether the circuit was broadcasted
if circuit.batch_size is None:
# If the circuit wasn't broadcasted, no action required
# Check whether the circuit was broadcasted or if the diff method is anything other than adjoint
if circuit.batch_size is None or execution_config.gradient_method != "adjoint":
# If the circuit wasn't broadcasted, or if built-in PennyLane broadcasting
# can be used, then no action required
circuits = [circuit]

def batch_fn(res: ResultBatch) -> Result:
Expand Down Expand Up @@ -333,6 +337,7 @@ def preprocess(
if isinstance(circuit_or_error, DeviceError):
raise circuit_or_error # it's an error

circuits, batch_fn = qml.transforms.map_batch_transform(batch_transform, circuits)
transform = partial(batch_transform, execution_config=execution_config)
circuits, batch_fn = qml.transforms.map_batch_transform(transform, circuits)

return circuits, batch_fn, _update_config(execution_config)
19 changes: 16 additions & 3 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,26 @@ def simulate(circuit: qml.tape.QuantumScript, rng=None, debugger=None) -> Result

state = create_initial_state(circuit.wires, circuit._prep[0] if circuit._prep else None)

# initial state is batched only if the state preparation (if it exists) is batched
is_state_batched = False
if circuit._prep and circuit._prep[0].batch_size is not None:
is_state_batched = True

for op in circuit._ops:
state = apply_operation(op, state, debugger=debugger)
state = apply_operation(op, state, is_state_batched=is_state_batched, debugger=debugger)

# new state is batched if i) the old state is batched, or ii) the new op adds a batch dim
is_state_batched = is_state_batched or op.batch_size is not None

if not circuit.shots:
# analytic case

if len(circuit.measurements) == 1:
return measure(circuit.measurements[0], state)

return tuple(measure(mp, state) for mp in circuit.measurements)
return tuple(
measure(mp, state, is_state_batched=is_state_batched) for mp in circuit.measurements
)

# finite-shot case

Expand All @@ -75,7 +85,10 @@ def simulate(circuit: qml.tape.QuantumScript, rng=None, debugger=None) -> Result

rng = default_rng(rng)
results = tuple(
measure_with_samples(mp, state, shots=circuit.shots, rng=rng) for mp in circuit.measurements
measure_with_samples(
mp, state, shots=circuit.shots, is_state_batched=is_state_batched, rng=rng
)
for mp in circuit.measurements
)

# no shot vector
Expand Down
6 changes: 5 additions & 1 deletion tests/devices/experimental/test_default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,11 @@ def test_broadcasted_parameter():
dev = DefaultQubit2()
x = np.array([0.536, 0.894])
qs = qml.tape.QuantumScript([qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))])
batch, post_processing_fn, config = dev.preprocess(qs)

config = ExecutionConfig()
config.gradient_method = "adjoint"
batch, post_processing_fn, config = dev.preprocess(qs, config)

assert len(batch) == 2
results = dev.execute(batch, config)
processed_results = post_processing_fn(results)
Expand Down
92 changes: 88 additions & 4 deletions tests/devices/qubit/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,33 @@ def test_batch_transform_no_batching(self):
input = (("a", "b"), "c")
assert batch_fn(input) == input[0]

def test_batch_transform_broadcast(self):
"""Test that batch_transform splits broadcasted tapes correctly."""
def test_batch_transform_broadcast_not_adjoint(self):
"""Test that batch_transform does nothing when batching is required but
internal PennyLane broadcasting can be used (diff method != adjoint)"""
ops = [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX([np.pi, np.pi / 2], wires=1)]
measurements = [qml.expval(qml.PauliZ(1))]
tape = QuantumScript(ops=ops, measurements=measurements)

tapes, batch_fn = batch_transform(tape)

assert len(tapes) == 1
for op, expected in zip(tapes[0].circuit, ops + measurements):
assert qml.equal(op, expected)

input = ([[1, 2], [3, 4]],)
assert np.array_equal(batch_fn(input), np.array([[1, 2], [3, 4]]))

def test_batch_transform_broadcast_adjoint(self):
"""Test that batch_transform splits broadcasted tapes correctly when
the diff method is adjoint"""
ops = [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX([np.pi, np.pi / 2], wires=1)]
measurements = [qml.expval(qml.PauliZ(1))]
tape = QuantumScript(ops=ops, measurements=measurements)

execution_config = ExecutionConfig()
execution_config.gradient_method = "adjoint"

tapes, batch_fn = batch_transform(tape, execution_config=execution_config)
expected_ops = [
[qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX(np.pi, wires=1)],
[qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX(np.pi / 2, wires=1)],
Expand Down Expand Up @@ -502,7 +522,7 @@ def test_config_choices_for_adjoint(self):
assert new_config.use_device_gradient
assert new_config.grad_on_execution

def test_preprocess_batch_transform(self):
def test_preprocess_batch_transform_not_adjoint(self):
"""Test that preprocess returns the correct tapes when a batch transform
is needed."""
ops = [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX([np.pi, np.pi / 2], wires=1)]
Expand All @@ -514,6 +534,35 @@ def test_preprocess_batch_transform(self):
]

res_tapes, batch_fn, _ = preprocess(tapes)

assert len(res_tapes) == 2
for i, t in enumerate(res_tapes):
for op, expected_op in zip(t.operations, ops):
assert qml.equal(op, expected_op)
assert len(t.measurements) == 1
if i == 0:
assert qml.equal(t.measurements[0], measurements[0])
else:
assert qml.equal(t.measurements[0], measurements[1])

input = ([[1, 2], [3, 4]], [[5, 6], [7, 8]])
assert np.array_equal(batch_fn(input), np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))

def test_preprocess_batch_transform_adjoint(self):
"""Test that preprocess returns the correct tapes when a batch transform
is needed."""
ops = [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX([np.pi, np.pi / 2], wires=1)]
# Need to specify grouping type to transform tape
measurements = [qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(1))]
tapes = [
QuantumScript(ops=ops, measurements=[measurements[0]]),
QuantumScript(ops=ops, measurements=[measurements[1]]),
]

execution_config = ExecutionConfig()
execution_config.gradient_method = "adjoint"

res_tapes, batch_fn, _ = preprocess(tapes, execution_config=execution_config)
expected_ops = [
[qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX(np.pi, wires=1)],
[qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX(np.pi / 2, wires=1)],
Expand Down Expand Up @@ -552,7 +601,7 @@ def test_preprocess_expand(self):
input = (("a", "b"), "c", "d")
assert batch_fn(input) == [("a", "b"), "c"]

def test_preprocess_split_and_expand(self):
def test_preprocess_split_and_expand_not_adjoint(self):
"""Test that preprocess returns the correct tapes when splitting and expanding
is needed."""
ops = [qml.Hadamard(0), NoMatOp(1), qml.RX([np.pi, np.pi / 2], wires=1)]
Expand All @@ -564,6 +613,41 @@ def test_preprocess_split_and_expand(self):
]

res_tapes, batch_fn, _ = preprocess(tapes)
expected_ops = [
qml.Hadamard(0),
qml.PauliX(1),
qml.PauliY(1),
qml.RX([np.pi, np.pi / 2], wires=1),
]

assert len(res_tapes) == 2
for i, t in enumerate(res_tapes):
for op, expected_op in zip(t.operations, expected_ops):
assert qml.equal(op, expected_op)
assert len(t.measurements) == 1
if i == 0:
assert qml.equal(t.measurements[0], measurements[0])
else:
assert qml.equal(t.measurements[0], measurements[1])

input = ([[1, 2], [3, 4]], [[5, 6], [7, 8]])
assert np.array_equal(batch_fn(input), np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))

def test_preprocess_split_and_expand_adjoint(self):
"""Test that preprocess returns the correct tapes when splitting and expanding
is needed."""
ops = [qml.Hadamard(0), NoMatOp(1), qml.RX([np.pi, np.pi / 2], wires=1)]
# Need to specify grouping type to transform tape
measurements = [qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(1))]
tapes = [
QuantumScript(ops=ops, measurements=[measurements[0]]),
QuantumScript(ops=ops, measurements=[measurements[1]]),
]

execution_config = ExecutionConfig()
execution_config.gradient_method = "adjoint"

res_tapes, batch_fn, _ = preprocess(tapes, execution_config=execution_config)
expected_ops = [
[qml.Hadamard(0), qml.PauliX(1), qml.PauliY(1), qml.RX(np.pi, wires=1)],
[qml.Hadamard(0), qml.PauliX(1), qml.PauliY(1), qml.RX(np.pi / 2, wires=1)],
Expand Down
74 changes: 74 additions & 0 deletions tests/devices/qubit/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,80 @@ def test_tf_results_and_backprop(self):
assert qml.math.allclose(grad1[0], -tf.sin(phi))


class TestBroadcasting:
"""Test that simulate works with broadcasted parameters"""

def test_broadcasted_prep_state(self):
"""Test that simulate works for state measurements
when the state prep has broadcasted parameters"""
x = np.array(1.2)

ops = [qml.RY(x, wires=0), qml.CNOT(wires=[0, 1])]
measurements = [qml.expval(qml.PauliZ(i)) for i in range(2)]
prep = [qml.QubitStateVector(np.eye(4), wires=[0, 1])]

qs = qml.tape.QuantumScript(ops, measurements, prep)
res = simulate(qs)

assert isinstance(res, tuple)
assert len(res) == 2
assert np.allclose(res[0], np.array([np.cos(x), np.cos(x), -np.cos(x), -np.cos(x)]))
assert np.allclose(res[1], np.array([np.cos(x), -np.cos(x), -np.cos(x), np.cos(x)]))

def test_broadcasted_op_state(self):
"""Test that simulate works for state measurements
when an operation has broadcasted parameters"""
x = np.array([0.8, 1.0, 1.2, 1.4])

ops = [qml.PauliX(wires=1), qml.RY(x, wires=0), qml.CNOT(wires=[0, 1])]
measurements = [qml.expval(qml.PauliZ(i)) for i in range(2)]

qs = qml.tape.QuantumScript(ops, measurements)
res = simulate(qs)

assert isinstance(res, tuple)
assert len(res) == 2
assert np.allclose(res[0], np.cos(x))
assert np.allclose(res[1], -np.cos(x))

def test_broadcasted_prep_sample(self):
"""Test that simulate works for sample measurements
when the state prep has broadcasted parameters"""
x = np.array(1.2)

ops = [qml.RY(x, wires=0), qml.CNOT(wires=[0, 1])]
measurements = [qml.expval(qml.PauliZ(i)) for i in range(2)]
prep = [qml.QubitStateVector(np.eye(4), wires=[0, 1])]

qs = qml.tape.QuantumScript(ops, measurements, prep, shots=qml.measurements.Shots(10000))
res = simulate(qs, rng=123)

assert isinstance(res, tuple)
assert len(res) == 2
assert np.allclose(
res[0], np.array([np.cos(x), np.cos(x), -np.cos(x), -np.cos(x)]), atol=0.05
)
assert np.allclose(
res[1], np.array([np.cos(x), -np.cos(x), -np.cos(x), np.cos(x)]), atol=0.05
)

def test_broadcasted_op_sample(self):
"""Test that simulate works for sample measurements
when an operation has broadcasted parameters"""
x = np.array([0.8, 1.0, 1.2, 1.4])

ops = [qml.PauliX(wires=1), qml.RY(x, wires=0), qml.CNOT(wires=[0, 1])]
measurements = [qml.expval(qml.PauliZ(i)) for i in range(2)]

qs = qml.tape.QuantumScript(ops, measurements, shots=qml.measurements.Shots(10000))
res = simulate(qs, rng=123)

assert isinstance(res, tuple)
assert len(res) == 2
assert np.allclose(res[0], np.cos(x), atol=0.05)
assert np.allclose(res[1], -np.cos(x), atol=0.05)


class TestDebugger:
"""Tests that the debugger works for a simple circuit"""

Expand Down

0 comments on commit 9256e47

Please sign in to comment.