From 9256e470051df39712423d9052106c51726a6fad Mon Sep 17 00:00:00 2001 From: Edward Jiang <34989448+eddddddy@users.noreply.github.com> Date: Fri, 16 Jun 2023 17:47:57 -0400 Subject: [PATCH] Support broadcasting in `simulate` and `preprocess` (#4244) * Support broadcasting in state measurements * docs for is_state_batched * Support broadcasting in sample measurements * Support broadcastingin simulate and preprocess --------- Co-authored-by: Matthew Silverman --- pennylane/devices/qubit/preprocess.py | 15 ++- pennylane/devices/qubit/simulate.py | 19 +++- .../experimental/test_default_qubit_2.py | 6 +- tests/devices/qubit/test_preprocess.py | 92 ++++++++++++++++++- tests/devices/qubit/test_simulate.py | 74 +++++++++++++++ 5 files changed, 193 insertions(+), 13 deletions(-) diff --git a/pennylane/devices/qubit/preprocess.py b/pennylane/devices/qubit/preprocess.py index 48773809b66..f1518193524 100644 --- a/pennylane/devices/qubit/preprocess.py +++ b/pennylane/devices/qubit/preprocess.py @@ -18,6 +18,7 @@ from dataclasses import replace from typing import Generator, Callable, Tuple, Union import warnings +from functools import partial import pennylane as qml @@ -241,7 +242,7 @@ def expand_fn(circuit: qml.tape.QuantumScript) -> qml.tape.QuantumScript: def batch_transform( - circuit: qml.tape.QuantumScript, + circuit: qml.tape.QuantumScript, execution_config: ExecutionConfig = DefaultExecutionConfig ) -> Tuple[Tuple[qml.tape.QuantumScript], PostprocessingFn]: """Apply a differentiable batch transform for preprocessing a circuit prior to execution. @@ -258,15 +259,18 @@ def batch_transform( Args: circuit (.QuantumTape): the circuit to preprocess + execution_config (.ExecutionConfig): execution configuration with configurable + options for the execution. Returns: tuple[Sequence[.QuantumTape], callable]: Returns a tuple containing the sequence of circuits to be executed, and a post-processing function to be applied to the list of evaluated circuit results. """ - # Check whether the circuit was broadcasted - if circuit.batch_size is None: - # If the circuit wasn't broadcasted, no action required + # Check whether the circuit was broadcasted or if the diff method is anything other than adjoint + if circuit.batch_size is None or execution_config.gradient_method != "adjoint": + # If the circuit wasn't broadcasted, or if built-in PennyLane broadcasting + # can be used, then no action required circuits = [circuit] def batch_fn(res: ResultBatch) -> Result: @@ -333,6 +337,7 @@ def preprocess( if isinstance(circuit_or_error, DeviceError): raise circuit_or_error # it's an error - circuits, batch_fn = qml.transforms.map_batch_transform(batch_transform, circuits) + transform = partial(batch_transform, execution_config=execution_config) + circuits, batch_fn = qml.transforms.map_batch_transform(transform, circuits) return circuits, batch_fn, _update_config(execution_config) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 86ec36ea66c..159fdf1497e 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -57,8 +57,16 @@ def simulate(circuit: qml.tape.QuantumScript, rng=None, debugger=None) -> Result state = create_initial_state(circuit.wires, circuit._prep[0] if circuit._prep else None) + # initial state is batched only if the state preparation (if it exists) is batched + is_state_batched = False + if circuit._prep and circuit._prep[0].batch_size is not None: + is_state_batched = True + for op in circuit._ops: - state = apply_operation(op, state, debugger=debugger) + state = apply_operation(op, state, is_state_batched=is_state_batched, debugger=debugger) + + # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim + is_state_batched = is_state_batched or op.batch_size is not None if not circuit.shots: # analytic case @@ -66,7 +74,9 @@ def simulate(circuit: qml.tape.QuantumScript, rng=None, debugger=None) -> Result if len(circuit.measurements) == 1: return measure(circuit.measurements[0], state) - return tuple(measure(mp, state) for mp in circuit.measurements) + return tuple( + measure(mp, state, is_state_batched=is_state_batched) for mp in circuit.measurements + ) # finite-shot case @@ -75,7 +85,10 @@ def simulate(circuit: qml.tape.QuantumScript, rng=None, debugger=None) -> Result rng = default_rng(rng) results = tuple( - measure_with_samples(mp, state, shots=circuit.shots, rng=rng) for mp in circuit.measurements + measure_with_samples( + mp, state, shots=circuit.shots, is_state_batched=is_state_batched, rng=rng + ) + for mp in circuit.measurements ) # no shot vector diff --git a/tests/devices/experimental/test_default_qubit_2.py b/tests/devices/experimental/test_default_qubit_2.py index f5313356201..1d5e03fb892 100644 --- a/tests/devices/experimental/test_default_qubit_2.py +++ b/tests/devices/experimental/test_default_qubit_2.py @@ -1235,7 +1235,11 @@ def test_broadcasted_parameter(): dev = DefaultQubit2() x = np.array([0.536, 0.894]) qs = qml.tape.QuantumScript([qml.RX(x, 0)], [qml.expval(qml.PauliZ(0))]) - batch, post_processing_fn, config = dev.preprocess(qs) + + config = ExecutionConfig() + config.gradient_method = "adjoint" + batch, post_processing_fn, config = dev.preprocess(qs, config) + assert len(batch) == 2 results = dev.execute(batch, config) processed_results = post_processing_fn(results) diff --git a/tests/devices/qubit/test_preprocess.py b/tests/devices/qubit/test_preprocess.py index 4b3ed1b1053..cee238fcabc 100644 --- a/tests/devices/qubit/test_preprocess.py +++ b/tests/devices/qubit/test_preprocess.py @@ -354,13 +354,33 @@ def test_batch_transform_no_batching(self): input = (("a", "b"), "c") assert batch_fn(input) == input[0] - def test_batch_transform_broadcast(self): - """Test that batch_transform splits broadcasted tapes correctly.""" + def test_batch_transform_broadcast_not_adjoint(self): + """Test that batch_transform does nothing when batching is required but + internal PennyLane broadcasting can be used (diff method != adjoint)""" ops = [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX([np.pi, np.pi / 2], wires=1)] measurements = [qml.expval(qml.PauliZ(1))] tape = QuantumScript(ops=ops, measurements=measurements) tapes, batch_fn = batch_transform(tape) + + assert len(tapes) == 1 + for op, expected in zip(tapes[0].circuit, ops + measurements): + assert qml.equal(op, expected) + + input = ([[1, 2], [3, 4]],) + assert np.array_equal(batch_fn(input), np.array([[1, 2], [3, 4]])) + + def test_batch_transform_broadcast_adjoint(self): + """Test that batch_transform splits broadcasted tapes correctly when + the diff method is adjoint""" + ops = [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX([np.pi, np.pi / 2], wires=1)] + measurements = [qml.expval(qml.PauliZ(1))] + tape = QuantumScript(ops=ops, measurements=measurements) + + execution_config = ExecutionConfig() + execution_config.gradient_method = "adjoint" + + tapes, batch_fn = batch_transform(tape, execution_config=execution_config) expected_ops = [ [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX(np.pi, wires=1)], [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX(np.pi / 2, wires=1)], @@ -502,7 +522,7 @@ def test_config_choices_for_adjoint(self): assert new_config.use_device_gradient assert new_config.grad_on_execution - def test_preprocess_batch_transform(self): + def test_preprocess_batch_transform_not_adjoint(self): """Test that preprocess returns the correct tapes when a batch transform is needed.""" ops = [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX([np.pi, np.pi / 2], wires=1)] @@ -514,6 +534,35 @@ def test_preprocess_batch_transform(self): ] res_tapes, batch_fn, _ = preprocess(tapes) + + assert len(res_tapes) == 2 + for i, t in enumerate(res_tapes): + for op, expected_op in zip(t.operations, ops): + assert qml.equal(op, expected_op) + assert len(t.measurements) == 1 + if i == 0: + assert qml.equal(t.measurements[0], measurements[0]) + else: + assert qml.equal(t.measurements[0], measurements[1]) + + input = ([[1, 2], [3, 4]], [[5, 6], [7, 8]]) + assert np.array_equal(batch_fn(input), np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])) + + def test_preprocess_batch_transform_adjoint(self): + """Test that preprocess returns the correct tapes when a batch transform + is needed.""" + ops = [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX([np.pi, np.pi / 2], wires=1)] + # Need to specify grouping type to transform tape + measurements = [qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(1))] + tapes = [ + QuantumScript(ops=ops, measurements=[measurements[0]]), + QuantumScript(ops=ops, measurements=[measurements[1]]), + ] + + execution_config = ExecutionConfig() + execution_config.gradient_method = "adjoint" + + res_tapes, batch_fn, _ = preprocess(tapes, execution_config=execution_config) expected_ops = [ [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX(np.pi, wires=1)], [qml.Hadamard(0), qml.CNOT([0, 1]), qml.RX(np.pi / 2, wires=1)], @@ -552,7 +601,7 @@ def test_preprocess_expand(self): input = (("a", "b"), "c", "d") assert batch_fn(input) == [("a", "b"), "c"] - def test_preprocess_split_and_expand(self): + def test_preprocess_split_and_expand_not_adjoint(self): """Test that preprocess returns the correct tapes when splitting and expanding is needed.""" ops = [qml.Hadamard(0), NoMatOp(1), qml.RX([np.pi, np.pi / 2], wires=1)] @@ -564,6 +613,41 @@ def test_preprocess_split_and_expand(self): ] res_tapes, batch_fn, _ = preprocess(tapes) + expected_ops = [ + qml.Hadamard(0), + qml.PauliX(1), + qml.PauliY(1), + qml.RX([np.pi, np.pi / 2], wires=1), + ] + + assert len(res_tapes) == 2 + for i, t in enumerate(res_tapes): + for op, expected_op in zip(t.operations, expected_ops): + assert qml.equal(op, expected_op) + assert len(t.measurements) == 1 + if i == 0: + assert qml.equal(t.measurements[0], measurements[0]) + else: + assert qml.equal(t.measurements[0], measurements[1]) + + input = ([[1, 2], [3, 4]], [[5, 6], [7, 8]]) + assert np.array_equal(batch_fn(input), np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])) + + def test_preprocess_split_and_expand_adjoint(self): + """Test that preprocess returns the correct tapes when splitting and expanding + is needed.""" + ops = [qml.Hadamard(0), NoMatOp(1), qml.RX([np.pi, np.pi / 2], wires=1)] + # Need to specify grouping type to transform tape + measurements = [qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(1))] + tapes = [ + QuantumScript(ops=ops, measurements=[measurements[0]]), + QuantumScript(ops=ops, measurements=[measurements[1]]), + ] + + execution_config = ExecutionConfig() + execution_config.gradient_method = "adjoint" + + res_tapes, batch_fn, _ = preprocess(tapes, execution_config=execution_config) expected_ops = [ [qml.Hadamard(0), qml.PauliX(1), qml.PauliY(1), qml.RX(np.pi, wires=1)], [qml.Hadamard(0), qml.PauliX(1), qml.PauliY(1), qml.RX(np.pi / 2, wires=1)], diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index c60ff83da70..1d8e07cf6e4 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -170,6 +170,80 @@ def test_tf_results_and_backprop(self): assert qml.math.allclose(grad1[0], -tf.sin(phi)) +class TestBroadcasting: + """Test that simulate works with broadcasted parameters""" + + def test_broadcasted_prep_state(self): + """Test that simulate works for state measurements + when the state prep has broadcasted parameters""" + x = np.array(1.2) + + ops = [qml.RY(x, wires=0), qml.CNOT(wires=[0, 1])] + measurements = [qml.expval(qml.PauliZ(i)) for i in range(2)] + prep = [qml.QubitStateVector(np.eye(4), wires=[0, 1])] + + qs = qml.tape.QuantumScript(ops, measurements, prep) + res = simulate(qs) + + assert isinstance(res, tuple) + assert len(res) == 2 + assert np.allclose(res[0], np.array([np.cos(x), np.cos(x), -np.cos(x), -np.cos(x)])) + assert np.allclose(res[1], np.array([np.cos(x), -np.cos(x), -np.cos(x), np.cos(x)])) + + def test_broadcasted_op_state(self): + """Test that simulate works for state measurements + when an operation has broadcasted parameters""" + x = np.array([0.8, 1.0, 1.2, 1.4]) + + ops = [qml.PauliX(wires=1), qml.RY(x, wires=0), qml.CNOT(wires=[0, 1])] + measurements = [qml.expval(qml.PauliZ(i)) for i in range(2)] + + qs = qml.tape.QuantumScript(ops, measurements) + res = simulate(qs) + + assert isinstance(res, tuple) + assert len(res) == 2 + assert np.allclose(res[0], np.cos(x)) + assert np.allclose(res[1], -np.cos(x)) + + def test_broadcasted_prep_sample(self): + """Test that simulate works for sample measurements + when the state prep has broadcasted parameters""" + x = np.array(1.2) + + ops = [qml.RY(x, wires=0), qml.CNOT(wires=[0, 1])] + measurements = [qml.expval(qml.PauliZ(i)) for i in range(2)] + prep = [qml.QubitStateVector(np.eye(4), wires=[0, 1])] + + qs = qml.tape.QuantumScript(ops, measurements, prep, shots=qml.measurements.Shots(10000)) + res = simulate(qs, rng=123) + + assert isinstance(res, tuple) + assert len(res) == 2 + assert np.allclose( + res[0], np.array([np.cos(x), np.cos(x), -np.cos(x), -np.cos(x)]), atol=0.05 + ) + assert np.allclose( + res[1], np.array([np.cos(x), -np.cos(x), -np.cos(x), np.cos(x)]), atol=0.05 + ) + + def test_broadcasted_op_sample(self): + """Test that simulate works for sample measurements + when an operation has broadcasted parameters""" + x = np.array([0.8, 1.0, 1.2, 1.4]) + + ops = [qml.PauliX(wires=1), qml.RY(x, wires=0), qml.CNOT(wires=[0, 1])] + measurements = [qml.expval(qml.PauliZ(i)) for i in range(2)] + + qs = qml.tape.QuantumScript(ops, measurements, shots=qml.measurements.Shots(10000)) + res = simulate(qs, rng=123) + + assert isinstance(res, tuple) + assert len(res) == 2 + assert np.allclose(res[0], np.cos(x), atol=0.05) + assert np.allclose(res[1], -np.cos(x), atol=0.05) + + class TestDebugger: """Tests that the debugger works for a simple circuit"""