diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 410f3beed9d..527261a78de 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -223,6 +223,9 @@ [(#5758)](https://github.com/PennyLaneAI/pennylane/pull/5758/) [(#5638)](https://github.com/PennyLaneAI/pennylane/pull/5638/) +* Device preprocess transforms now happen inside the ml boundary. + [(#5791)](https://github.com/PennyLaneAI/pennylane/pull/5791) + * `qml.qchem.molecular_dipole` function is added for calculating the dipole operator using "dhf" and "openfermion" backends. [(#5764)](https://github.com/PennyLaneAI/pennylane/pull/5764) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index c0a8a99f48d..6b1388802e9 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -502,7 +502,10 @@ def preprocess( transform_program.add_transform(validate_device_wires, self.wires, name=self.name) transform_program.add_transform( - mid_circuit_measurements, device=self, mcm_config=config.mcm_config + mid_circuit_measurements, + device=self, + mcm_config=config.mcm_config, + interface=config.interface, ) transform_program.add_transform( decompose, diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 23fed7614e6..eb465af72af 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -147,7 +147,10 @@ def validate_device_wires( @transform def mid_circuit_measurements( - tape: qml.tape.QuantumTape, device, mcm_config=MCMConfig() + tape: qml.tape.QuantumTape, + device, + mcm_config=MCMConfig(), + interface=None, ) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Provide the transform to handle mid-circuit measurements. @@ -162,7 +165,7 @@ def mid_circuit_measurements( mcm_method = "one-shot" if tape.shots else "deferred" if mcm_method == "one-shot": - return qml.dynamic_one_shot(tape) + return qml.dynamic_one_shot(tape, interface=interface) return qml.defer_measurements(tape, device=device) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 70cdf6fca28..1764b554ed3 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -28,14 +28,14 @@ class ConditionalTransformError(ValueError): """Error for using qml.cond incorrectly""" -class Conditional(SymbolicOp): +class Conditional(SymbolicOp, Operation): """A Conditional Operation. Unless you are a Pennylane plugin developer, **you should NOT directly use this class**, instead, use the :func:`qml.cond <.cond>` function. The ``Conditional`` class is a container class that defines an operation - that should by applied relative to a single measurement value. + that should be applied relative to a single measurement value. Support for executing ``Conditional`` operations is device-dependent. 
If a device doesn't support mid-circuit measurements natively, then the QNode @@ -54,13 +54,15 @@ def __init__(self, expr, then_op: Type[Operation], id=None): self.hyperparameters["meas_val"] = expr self._name = f"Conditional({then_op.name})" super().__init__(then_op, id=id) + if self.grad_recipe is None: + self.grad_recipe = [None] * self.num_params def label(self, decimals=None, base_label=None, cache=None): return self.base.label(decimals=decimals, base_label=base_label, cache=cache) @property def meas_val(self): - "the measurement outcome value to consider from `expr` argument" + """the measurement outcome value to consider from `expr` argument""" return self.hyperparameters["meas_val"] @property diff --git a/pennylane/transforms/core/transform_program.py b/pennylane/transforms/core/transform_program.py index 31eac21e420..47d53b7f53e 100644 --- a/pennylane/transforms/core/transform_program.py +++ b/pennylane/transforms/core/transform_program.py @@ -17,8 +17,6 @@ from functools import partial from typing import Callable, List, Optional, Sequence, Tuple, Union -import numpy as np - import pennylane as qml from pennylane.tape import QuantumTape from pennylane.typing import Result, ResultBatch @@ -354,24 +352,32 @@ def set_classical_component(self, qnode, args, kwargs): self._set_all_classical_jacobians(qnode, args, kwargs, argnums) self._set_all_argnums(qnode, args, kwargs, argnums) - def prune_dynamic_transform(self): - """Ensure a single ``dynamic_one_shot`` transform is applied.""" - trans_type = np.zeros(len(self._transform_program), dtype=np.int32) - for i, t in enumerate(self._transform_program): - if "dynamic_one_shot" in str(t): - trans_type[i] = 1 - if "mid_circuit_measurements" in str(t): - trans_type[i] = 2 - if sum(trans_type) < 2: - return - keep = 2 if 2 in trans_type else 1 + def prune_dynamic_transform(self, type_to_keep=1): + """Ensure that at most one ``dynamic_one_shot`` transform is applied. + + Args: + type_to_keep (int): The type of the dynamic transform to keep. 0: keep none, + 1: dynamic_one_shot or mid_circuit_measurements, 2: only mid_circuit_measurements. + + Returns: + bool: ``True`` if a dynamic transform was found, ``False`` otherwise.
+ + """ + + i = len(self._transform_program) - 1 found = False - for i, ttype in enumerate(reversed(trans_type)): - if not found and ttype == keep: + while i >= 0: + t = self._transform_program[i] + if "mid_circuit_measurements" in str(t) and type_to_keep > 0: + type_to_keep = 0 # keep this and do not keep the rest + found = True + elif "dynamic_one_shot" in str(t) and type_to_keep == 1: + type_to_keep = 0 # keep this and do not keep the rest found = True - continue - if found and ttype in [1, 2]: - self._transform_program.pop(len(self._transform_program) - 1 - i) + elif "dynamic_one_shot" in str(t) or "mid_circuit_measurements" in str(t): + self._transform_program.pop(i) + i -= 1 + return found def _set_all_classical_jacobians( self, qnode, args, kwargs, argnums diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index b05c1d583b0..cfaef28a225 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -124,6 +124,8 @@ def func(x, y): aux_tapes = [init_auxiliary_tape(t) for t in tapes] + interface = kwargs.get("interface", None) + def reshape_data(array): return qml.math.squeeze(qml.math.vstack(array)) @@ -158,7 +160,7 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None): results = [ reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0]) ] - return parse_native_mid_circuit_measurements(tape, aux_tapes, results) + return parse_native_mid_circuit_measurements(tape, aux_tapes, results, interface=interface) return aux_tapes, processing_fn @@ -221,7 +223,10 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): # pylint: disable=too-many-branches,too-many-statements def parse_native_mid_circuit_measurements( - circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results: TensorLike + circuit: qml.tape.QuantumScript, + aux_tapes: qml.tape.QuantumScript, + results: TensorLike, + interface=None, ): """Combines, gathers and normalizes the results of native mid-circuit measurement runs. @@ -241,7 +246,7 @@ def measurement_with_no_shots(measurement): else np.nan ) - interface = qml.math.get_deep_interface(circuit.data) + interface = interface or qml.math.get_deep_interface(circuit.data) interface = "numpy" if interface == "builtins" else interface active_qjit = qml.compiler.active() diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 13fd2ce7cea..b020cb7c7f3 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -253,13 +253,19 @@ def device_expansion_function(tape): # pylint: disable=function-redefined def _make_inner_execute( - device, override_shots, cache, expand_fn=None, execution_config=None, numpy_only=True + device, + override_shots, + cache, + inner_transform, + expand_fn=None, + execution_config=None, + numpy_only=True, ) -> Callable: """Construct the function that will execute the tapes inside the ml framework registration for the 1st order derivatives. 
Steps in between the ml framework execution and the device are: - - device expansion (old device) + - device expansion (old device) or device preprocessing (new device) - conversion to numpy - caching @@ -287,11 +293,12 @@ def inner_execute(tapes: Sequence[QuantumTape], **_) -> ResultBatch: Closure Variables: expand_fn (Callable[[QuantumTape], QuantumTape]): A device preprocessing step - numpy_only (bool): whether or not to convert the data to numpy or leave as is + numpy_only (bool): whether to convert the data to numpy or leave as is device_execution (Callable[[Sequence[QuantumTape]], ResultBatch]) cache (None | MutableMapping): The cache to use. If ``None``, caching will not occur. """ - transform_program = qml.transforms.core.TransformProgram() + + transform_program = qml.transforms.core.TransformProgram(inner_transform) if numpy_only: transform_program.add_transform(qml.transforms.convert_to_numpy_parameters) @@ -299,11 +306,11 @@ def inner_execute(tapes: Sequence[QuantumTape], **_) -> ResultBatch: if cache is not None: transform_program.add_transform(_cache_transform, cache=cache) + transformed_tapes, transform_post_processing = transform_program(tapes) + # TODO: Apply expand_fn() as transform. if expand_fn: - tapes = tuple(expand_fn(t) for t in tapes) - - transformed_tapes, transform_post_processing = transform_program(tapes) + transformed_tapes = tuple(expand_fn(t) for t in transformed_tapes) if transformed_tapes: results = device_execution(transformed_tapes) @@ -407,6 +414,7 @@ def execute( gradient_fn: Optional[Union[Callable, str]] = None, interface="auto", transform_program=None, + inner_transform=None, config=None, grad_on_execution="best", gradient_kwargs=None, @@ -435,6 +443,7 @@ def execute( This affects the types of parameters that can exist on the input tapes. Available options include ``autograd``, ``torch``, ``tf``, ``jax`` and ``auto``. transform_program(.TransformProgram): A transform program to be applied to the initial tape. + inner_transform (.TransformProgram): A transform program to be applied to the tapes in inner execution, inside the ml interface. config (qml.devices.ExecutionConfig): A datastructure describing the parameters needed to fully describe the execution. grad_on_execution (bool, str): Whether the gradients should be computed on the execution or not. Only applies if the device is queried for the gradient; gradient transform @@ -587,11 +596,10 @@ def cost_fn(params, x): ) config.mcm_config.postselect_mode = "fill-shots" - if transform_program is None: - if isinstance(device, qml.devices.Device): - transform_program = device.preprocess(config)[0] - else: - transform_program = qml.transforms.core.TransformProgram() + is_gradient_transform = isinstance(gradient_fn, qml.transforms.core.TransformDispatcher) + transform_program, inner_transform = _make_transform_programs( + device, config, inner_transform, transform_program, is_gradient_transform + ) # If caching is desired but an explicit cache is not provided, use an ``LRUCache``. 
if cache is True: @@ -617,6 +625,7 @@ def cost_fn(params, x): device, override_shots, cache, + inner_transform, expand_fn, config, numpy_only=not device_supports_interface_data, @@ -754,7 +763,9 @@ def device_execute_and_gradients(internal_tapes, **gradient_kwargs): else: # need to override to have no cache - inner_execute = _make_inner_execute(device, override_shots, cache=None) + inner_execute = _make_inner_execute( + device, override_shots, cache=None, inner_transform=inner_transform + ) def inner_execute_with_empty_jac(tapes, **_): return (inner_execute(tapes), []) @@ -830,6 +841,35 @@ def device_gradient_fn(inner_tapes, **gradient_kwargs): return post_processing(results) +def _make_transform_programs( + device, config, inner_transform, transform_program, is_gradient_transform +): + """helper function to make the transform programs.""" + + if isinstance(device, qml.devices.Device): + + # If gradient_fn is a gradient transform, device preprocessing should happen in + # inner execute (inside the ml boundary). + if is_gradient_transform: + if inner_transform is None: + inner_transform = device.preprocess(config)[0] + if transform_program is None: + transform_program = qml.transforms.core.TransformProgram() + else: + if inner_transform is None: + inner_transform = qml.transforms.core.TransformProgram() + if transform_program is None: + transform_program = device.preprocess(config)[0] + + else: + if transform_program is None: + transform_program = qml.transforms.core.TransformProgram() + if inner_transform is None: + inner_transform = qml.transforms.core.TransformProgram() + + return transform_program, inner_transform + + def _get_execution_config( gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ): diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 42a0c39f09a..f07dbdb0810 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1053,14 +1053,19 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml if mcm_config.mcm_method == "single-branch-statistics": raise ValueError("Cannot use mcm_method='single-branch-statistics' without qml.qjit.") - # Add the device program to the QNode program + full_transform_program = qml.transforms.core.TransformProgram(self.transform_program) + inner_transform_program = qml.transforms.core.TransformProgram() + config = None + if isinstance(self.device, qml.devices.Device): + config = _make_execution_config(self, self.gradient_fn) device_transform_program, config = self.device.preprocess(execution_config=config) - full_transform_program = self.transform_program + device_transform_program - else: - config = None - full_transform_program = qml.transforms.core.TransformProgram(self.transform_program) + + if config.use_device_gradient: + full_transform_program += device_transform_program + else: + inner_transform_program += device_transform_program has_mcm_support = ( any(isinstance(op, MidMeasureMP) for op in self._tape) @@ -1068,14 +1073,15 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml and self.device.capabilities().get("supports_mid_measure", False) ) if has_mcm_support: - full_transform_program.add_transform( + inner_transform_program.add_transform( qml.devices.preprocess.mid_circuit_measurements, device=self.device, mcm_config=mcm_config, + interface=self.interface, ) override_shots = 1 elif hasattr(self.device, "capabilities"): - full_transform_program.add_transform( + inner_transform_program.add_transform( 
qml.defer_measurements, device=self.device, ) @@ -1086,9 +1092,10 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml qml.transform(self.gradient_fn.expand_transform), **self.gradient_kwargs, ) + # Calculate the classical jacobians if necessary full_transform_program.set_classical_component(self, args, kwargs) - full_transform_program.prune_dynamic_transform() + _prune_dynamic_transform(full_transform_program, inner_transform_program) # pylint: disable=unexpected-keyword-arg res = qml.execute( @@ -1097,6 +1104,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml gradient_fn=self.gradient_fn, interface=self.interface, transform_program=full_transform_program, + inner_transform=inner_transform_program, config=config, gradient_kwargs=self.gradient_kwargs, override_shots=override_shots, @@ -1157,3 +1165,30 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: qnode = lambda device, **kwargs: functools.partial(QNode, device=device, **kwargs) qnode.__doc__ = QNode.__doc__ qnode.__signature__ = inspect.signature(QNode) + + +def _prune_dynamic_transform(outer_transform, inner_transform): + """Ensure a single ``dynamic_one_shot`` transform is applied. + + Sometimes device preprocess contains a ``mid_circuit_measurements`` transform, which will + be added to the inner transform program. If the user then applies a ``dynamic_one_shot`` + manually, it will duplicate the ``mid_circuit_measurements`` transform. This function ensures + that there is only one ``dynamic_one_shot`` transform in the outer and inner transform + programs combined. + + """ + + all_transforms = outer_transform + inner_transform + type_to_keep = 0 + if any("mid_circuit_measurements" in str(t) for t in all_transforms): + type_to_keep = 2 + elif any("dynamic_one_shot" in str(t) for t in all_transforms): + type_to_keep = 1 + + if type_to_keep == 0: + return + + dynamic_transform_found = inner_transform.prune_dynamic_transform(type_to_keep) + if dynamic_transform_found: + type_to_keep = 0 + outer_transform.prune_dynamic_transform(type_to_keep) diff --git a/tests/docs/test_supported_confs.py b/tests/docs/test_supported_confs.py index bebba067858..2859ffea0b6 100644 --- a/tests/docs/test_supported_confs.py +++ b/tests/docs/test_supported_confs.py @@ -412,7 +412,7 @@ def test_all_paramshift_state(self, interface, return_type, shots, wire_specs): # with pytest.raises(ValueError, match=msg): circuit = get_qnode(interface, "parameter-shift", return_type, shots, wire_specs) x = get_variable(interface, wire_specs, complex=complex) - if shots is not None: + if shots is not None and interface != "jax": with pytest.raises(qml.DeviceError, match="not accepted with finite shots"): compute_gradient(x, interface, circuit, return_type, complex=complex) else: @@ -540,7 +540,7 @@ def test_all_hadamard_nonstate_non_var( circuit = get_qnode(interface, diff_method, return_type, shots, wire_specs) x = get_variable(interface, wire_specs) if return_type in (VnEntropy, MutualInfo): - if shots: + if shots and interface != "jax": err_cls = qml.DeviceError msg = "not accepted with finite shots" else: diff --git a/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py index cb1306d7b59..065403f9ff3 100644 --- a/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py 
@@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Integration tests for using the JAX-JIT interface with a QNode""" +import copy + # pylint: disable=too-many-arguments,too-few-public-methods from functools import partial @@ -1374,6 +1376,7 @@ def test_state(self, dev, diff_method, grad_on_execution, device_vjp, interface, x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) if not dev.wires: + dev = copy.copy(dev) dev._wires = qml.wires.Wires([0, 1]) # pylint:disable=protected-access @qnode( diff --git a/tests/test_qnode.py b/tests/test_qnode.py index dbbdfa48ed3..b11ca5db1e5 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -29,6 +29,7 @@ from pennylane import numpy as pnp from pennylane import qnode from pennylane.tape import QuantumScript +from pennylane.workflow.qnode import _prune_dynamic_transform def dummyfunc(): @@ -2026,3 +2027,48 @@ def circuit(x): circuit(qml.numpy.array(0.1)) assert circuit.interface == "auto" + + +def test_prune_dynamic_transform(): + """Tests that the helper function prune dynamic transform works.""" + + program1 = qml.transforms.core.TransformProgram( + [ + qml.transforms.dynamic_one_shot, + qml.transforms.sum_expand, + qml.transforms.dynamic_one_shot, + ] + ) + program2 = qml.transforms.core.TransformProgram( + [ + qml.transforms.dynamic_one_shot, + qml.transforms.sum_expand, + ] + ) + + _prune_dynamic_transform(program1, program2) + assert len(program1) == 1 + assert len(program2) == 2 + + +def test_prune_dynamic_transform_with_mcm(): + """Tests that the helper function prune dynamic transform works with mcm""" + + program1 = qml.transforms.core.TransformProgram( + [ + qml.transforms.dynamic_one_shot, + qml.transforms.sum_expand, + qml.devices.preprocess.mid_circuit_measurements, + ] + ) + program2 = qml.transforms.core.TransformProgram( + [ + qml.transforms.dynamic_one_shot, + qml.transforms.sum_expand, + ] + ) + + _prune_dynamic_transform(program1, program2) + assert len(program1) == 2 + assert qml.devices.preprocess.mid_circuit_measurements in program1 + assert len(program2) == 1 diff --git a/tests/transforms/core/test_transform_program.py b/tests/transforms/core/test_transform_program.py index d3c9747d122..30fdc6266cf 100644 --- a/tests/transforms/core/test_transform_program.py +++ b/tests/transforms/core/test_transform_program.py @@ -356,6 +356,13 @@ def test_empty_program(self): ): program.get_last() + def test_get_last(self): + """Tests the get_last method""" + program = TransformProgram() + program.add_transform(transform(first_valid_transform)) + program.add_transform(transform(second_valid_transform)) + assert program.get_last() == TransformContainer(transform=second_valid_transform) + def test_push_back(self): """Test to push back multiple transforms into a program and also the different methods of a program.""" transform_program = TransformProgram() diff --git a/tests/workflow/test_construct_batch.py b/tests/workflow/test_construct_batch.py index e80aaee331b..44e345e552a 100644 --- a/tests/workflow/test_construct_batch.py +++ b/tests/workflow/test_construct_batch.py @@ -106,7 +106,8 @@ def circuit(): assert p_dev == p_default assert p_none == p_dev assert len(p_dev) == 9 - assert p_dev == p_grad + dev.preprocess()[0] + config = qml.devices.ExecutionConfig(interface=getattr(circuit, "interface", None)) + assert p_dev == p_grad + dev.preprocess(config)[0] # slicing p_sliced = get_transform_program(circuit, slice(2, 7, 2)) @@ -142,7 +143,9 @@ def circuit(x): 
assert len(full_prog) == 13 config = qml.devices.ExecutionConfig( - gradient_method="adjoint", use_device_jacobian_product=False + interface=getattr(circuit, "interface", None), + gradient_method="adjoint", + use_device_jacobian_product=False, ) dev_program = dev.preprocess(config)[0] @@ -194,7 +197,8 @@ def circuit(): assert grad_program[2].transform == qml.gradients.param_shift.expand_transform dev_program = get_transform_program(circuit, level="device") - assert len(dev_program) == 3 + len(circuit.device.preprocess()[0]) # currently 8 + config = qml.devices.ExecutionConfig(interface=getattr(circuit, "interface", None)) + assert len(dev_program) == 3 + len(circuit.device.preprocess(config)[0]) # currently 8 assert qml.metric_tensor not in dev_program full = get_transform_program(circuit)
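The QNode-level scenario behind `_prune_dynamic_transform` can be illustrated with a short, self-contained sketch (illustrative only, not part of this diff; it assumes the public `qml.dynamic_one_shot`, `qml.measure`, and `qml.cond` APIs and the standard `default.qubit` device): the user applies `dynamic_one_shot` manually while device preprocessing also schedules `mid_circuit_measurements`, and after this change only one dynamic transform survives across the outer and inner transform programs.

import pennylane as qml
from pennylane import numpy as pnp

dev = qml.device("default.qubit", wires=2, shots=100)

# The user applies dynamic_one_shot manually; device preprocessing would also
# add mid_circuit_measurements. With a gradient transform such as
# parameter-shift, device preprocessing is now placed in the inner transform
# program (inside the ml boundary), and _prune_dynamic_transform keeps a single
# dynamic transform across the outer and inner programs, so the one-shot
# expansion runs exactly once.
@qml.dynamic_one_shot
@qml.qnode(dev, diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    m0 = qml.measure(0)
    qml.cond(m0, qml.RY)(x, wires=1)
    return qml.expval(qml.PauliZ(1))

x = pnp.array(0.4, requires_grad=True)
print(circuit(x))  # executes with 100 shots; per-shot results are gathered and post-processed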