Move device preprocessing to inner execute (#5791)
**Context:**
If there is a gradient transform, we want device preprocessing to happen
within the ml boundary, i.e., in `inner_execute`, to avoid applying the
gradient transform on transformed tapes produced by device
preprocessing.
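
As a minimal sketch of the workflow this affects (the device, tape, and parameter are illustrative, not taken from this PR):

```python
import numpy as np
import pennylane as qml

dev = qml.devices.DefaultQubit()
tape = qml.tape.QuantumScript([qml.RX(np.array(0.5), 0)], [qml.expval(qml.PauliZ(0))])

# param_shift is a gradient transform, so with this change it sees the
# original tape; device preprocessing (validation, decomposition, ...)
# runs later, inside inner_execute.
results = qml.execute([tape], dev, gradient_fn=qml.gradients.param_shift)
```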

**Description of the Change:**
1. Adds an `inner_transform` keyword to `qml.execute`.
2. Updates `QNode._execute_component` and `qml.execute` to add device
preprocessing to the inner transform program when `gradient_fn` is a
gradient transform.

**Related GitHub Issues:**
[sc-59107]

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: Christina Lee <christina@xanadu.ai>
3 people committed Jun 12, 2024
1 parent f0bc612 commit 4eb803a
Showing 13 changed files with 210 additions and 53 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
@@ -223,6 +223,9 @@
   [(#5758)](https://github.com/PennyLaneAI/pennylane/pull/5758/)
   [(#5638)](https://github.com/PennyLaneAI/pennylane/pull/5638/)

+* Device preprocess transforms now happen inside the ml boundary.
+  [(#5791)](https://github.com/PennyLaneAI/pennylane/pull/5791)
+
 * `qml.qchem.molecular_dipole` function is added for calculating the dipole operator using "dhf" and "openfermion" backends.
   [(#5764)](https://github.com/PennyLaneAI/pennylane/pull/5764)
5 changes: 4 additions & 1 deletion pennylane/devices/default_qubit.py
@@ -502,7 +502,10 @@ def preprocess(

         transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
         transform_program.add_transform(
-            mid_circuit_measurements, device=self, mcm_config=config.mcm_config
+            mid_circuit_measurements,
+            device=self,
+            mcm_config=config.mcm_config,
+            interface=config.interface,
         )
         transform_program.add_transform(
             decompose,
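
A small sketch of the effect of this change (the config is illustrative): the interface stored on the execution config is now bound into the ``mid_circuit_measurements`` transform when the device builds its preprocessing program.

```python
import pennylane as qml

dev = qml.devices.DefaultQubit()
config = qml.devices.ExecutionConfig(interface="autograd")

# preprocess() returns (transform_program, config); the program now carries
# mid_circuit_measurements with interface=config.interface bound as a kwarg.
program, processed_config = dev.preprocess(config)
```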
7 changes: 5 additions & 2 deletions pennylane/devices/preprocess.py
@@ -147,7 +147,10 @@ def validate_device_wires(

 @transform
 def mid_circuit_measurements(
-    tape: qml.tape.QuantumTape, device, mcm_config=MCMConfig()
+    tape: qml.tape.QuantumTape,
+    device,
+    mcm_config=MCMConfig(),
+    interface=None,
 ) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
     """Provide the transform to handle mid-circuit measurements.

@@ -162,7 +165,7 @@ def mid_circuit_measurements(
     mcm_method = "one-shot" if tape.shots else "deferred"

     if mcm_method == "one-shot":
-        return qml.dynamic_one_shot(tape)
+        return qml.dynamic_one_shot(tape, interface=interface)
     return qml.defer_measurements(tape, device=device)
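
A usage sketch of the dispatch (the tape is illustrative): finite shots route to ``dynamic_one_shot``, which now receives the interface hint, while analytic execution still routes to ``defer_measurements``.

```python
import pennylane as qml

dev = qml.devices.DefaultQubit()

with qml.queuing.AnnotatedQueue() as q:
    qml.RX(0.4, wires=0)
    m = qml.measure(0)
    qml.expval(qml.PauliZ(0))

# shots=10 -> "one-shot", so qml.dynamic_one_shot(tape, interface=...) is applied
tape = qml.tape.QuantumScript.from_queue(q, shots=10)
batch, fn = qml.devices.preprocess.mid_circuit_measurements(tape, dev, interface="numpy")
```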
8 changes: 5 additions & 3 deletions pennylane/ops/op_math/condition.py
@@ -28,14 +28,14 @@ class ConditionalTransformError(ValueError):
     """Error for using qml.cond incorrectly"""


-class Conditional(SymbolicOp):
+class Conditional(SymbolicOp, Operation):
     """A Conditional Operation.

     Unless you are a Pennylane plugin developer, **you should NOT directly use this class**,
     instead, use the :func:`qml.cond <.cond>` function.

     The ``Conditional`` class is a container class that defines an operation
-    that should by applied relative to a single measurement value.
+    that should be applied relative to a single measurement value.

     Support for executing ``Conditional`` operations is device-dependent. If a
     device doesn't support mid-circuit measurements natively, then the QNode
@@ -54,13 +54,15 @@ def __init__(self, expr, then_op: Type[Operation], id=None):
         self.hyperparameters["meas_val"] = expr
         self._name = f"Conditional({then_op.name})"
         super().__init__(then_op, id=id)
+        if self.grad_recipe is None:
+            self.grad_recipe = [None] * self.num_params

     def label(self, decimals=None, base_label=None, cache=None):
         return self.base.label(decimals=decimals, base_label=base_label, cache=cache)

     @property
     def meas_val(self):
-        "the measurement outcome value to consider from `expr` argument"
+        """the measurement outcome value to consider from `expr` argument"""
         return self.hyperparameters["meas_val"]

     @property
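
A brief sketch of what the new base class enables (the printed values are assumptions based on the diff): a ``Conditional`` built by ``qml.cond`` is now an ``Operation``, so gradient machinery can query attributes such as ``grad_recipe`` on it.

```python
import pennylane as qml

with qml.queuing.AnnotatedQueue() as q:
    m = qml.measure(0)
    qml.cond(m, qml.RY)(0.3, wires=1)

cond_op = q.queue[-1]  # the Conditional wrapping RY
print(isinstance(cond_op, qml.operation.Operation))  # True with the new base class
print(cond_op.grad_recipe)  # filled to [None] * num_params when the base defines none
```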
42 changes: 24 additions & 18 deletions pennylane/transforms/core/transform_program.py
@@ -17,8 +17,6 @@
 from functools import partial
 from typing import Callable, List, Optional, Sequence, Tuple, Union

-import numpy as np
-
 import pennylane as qml
 from pennylane.tape import QuantumTape
 from pennylane.typing import Result, ResultBatch
@@ -354,24 +352,32 @@ def set_classical_component(self, qnode, args, kwargs):
         self._set_all_classical_jacobians(qnode, args, kwargs, argnums)
         self._set_all_argnums(qnode, args, kwargs, argnums)

-    def prune_dynamic_transform(self):
-        """Ensure a single ``dynamic_one_shot`` transform is applied."""
-        trans_type = np.zeros(len(self._transform_program), dtype=np.int32)
-        for i, t in enumerate(self._transform_program):
-            if "dynamic_one_shot" in str(t):
-                trans_type[i] = 1
-            if "mid_circuit_measurements" in str(t):
-                trans_type[i] = 2
-        if sum(trans_type) < 2:
-            return
-        keep = 2 if 2 in trans_type else 1
-        found = False
-        for i, ttype in enumerate(reversed(trans_type)):
-            if not found and ttype == keep:
-                found = True
-                continue
-            if found and ttype in [1, 2]:
-                self._transform_program.pop(len(self._transform_program) - 1 - i)
+    def prune_dynamic_transform(self, type_to_keep=1):
+        """Ensures that only one or none ``dynamic_one_shot`` is applied.
+
+        Args:
+            type_to_keep (int): The type of the dynamic transform to keep. 0: keep none,
+                1: dynamic_one_shot or mid_circuit_measurements, 2: only mid_circuit_measurements.
+
+        Returns:
+            bool: ``True`` if a dynamic transform was found, ``False`` otherwise.
+        """
+
+        i = len(self._transform_program) - 1
+        found = False
+        while i >= 0:
+            t = self._transform_program[i]
+            if "mid_circuit_measurements" in str(t) and type_to_keep > 0:
+                type_to_keep = 0  # keep this and do not keep the rest
+                found = True
+            elif "dynamic_one_shot" in str(t) and type_to_keep == 1:
+                type_to_keep = 0  # keep this and do not keep the rest
+                found = True
+            elif "dynamic_one_shot" in str(t) or "mid_circuit_measurements" in str(t):
+                self._transform_program.pop(i)
+            i -= 1
+        return found

     def _set_all_classical_jacobians(
         self, qnode, args, kwargs, argnums
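
A usage sketch of the reworked pruning (the program contents are hypothetical): the scan now runs from the end of the program, keeps the last transform matching ``type_to_keep``, and pops the rest.

```python
import pennylane as qml
from pennylane.transforms.core import TransformProgram

program = TransformProgram()
program.add_transform(qml.dynamic_one_shot)
program.add_transform(qml.dynamic_one_shot)

found = program.prune_dynamic_transform()  # type_to_keep=1 keeps only the last one
print(found, len(program))  # True 1
```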
11 changes: 8 additions & 3 deletions pennylane/transforms/dynamic_one_shot.py
@@ -124,6 +124,8 @@ def func(x, y):

     aux_tapes = [init_auxiliary_tape(t) for t in tapes]

+    interface = kwargs.get("interface", None)
+
     def reshape_data(array):
         return qml.math.squeeze(qml.math.vstack(array))

@@ -158,7 +160,7 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None):
         results = [
             reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0])
         ]
-        return parse_native_mid_circuit_measurements(tape, aux_tapes, results)
+        return parse_native_mid_circuit_measurements(tape, aux_tapes, results, interface=interface)

     return aux_tapes, processing_fn

@@ -221,7 +223,10 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript):

 # pylint: disable=too-many-branches,too-many-statements
 def parse_native_mid_circuit_measurements(
-    circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results: TensorLike
+    circuit: qml.tape.QuantumScript,
+    aux_tapes: qml.tape.QuantumScript,
+    results: TensorLike,
+    interface=None,
 ):
     """Combines, gathers and normalizes the results of native mid-circuit measurement runs.

@@ -241,7 +246,7 @@ def measurement_with_no_shots(measurement):
             else np.nan
         )

-    interface = qml.math.get_deep_interface(circuit.data)
+    interface = interface or qml.math.get_deep_interface(circuit.data)
     interface = "numpy" if interface == "builtins" else interface
     active_qjit = qml.compiler.active()
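
As a sketch (illustrative tape), the caller can now state the interface up front instead of letting ``parse_native_mid_circuit_measurements`` infer it from ``circuit.data``, which helps when the tape parameters are plain floats:

```python
import pennylane as qml

with qml.queuing.AnnotatedQueue() as q:
    qml.RX(0.4, wires=0)  # a plain-float parameter carries no interface information
    m = qml.measure(0)
    qml.expval(qml.PauliZ(0))

tape = qml.tape.QuantumScript.from_queue(q, shots=10)

# The kwarg is picked up via kwargs.get("interface", None) and forwarded
# to parse_native_mid_circuit_measurements by the processing function.
aux_tapes, processing_fn = qml.dynamic_one_shot(tape, interface="numpy")
```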
66 changes: 53 additions & 13 deletions pennylane/workflow/execution.py
@@ -253,13 +253,19 @@ def device_expansion_function(tape):  # pylint: disable=function-redefined


 def _make_inner_execute(
-    device, override_shots, cache, expand_fn=None, execution_config=None, numpy_only=True
+    device,
+    override_shots,
+    cache,
+    inner_transform,
+    expand_fn=None,
+    execution_config=None,
+    numpy_only=True,
 ) -> Callable:
     """Construct the function that will execute the tapes inside the ml framework registration
     for the 1st order derivatives.

     Steps in between the ml framework execution and the device are:
-    - device expansion (old device)
+    - device expansion (old device) or device preprocessing (new device)
     - conversion to numpy
     - caching
@@ -287,23 +293,24 @@ def inner_execute(tapes: Sequence[QuantumTape], **_) -> ResultBatch:

         Closure Variables:
             expand_fn (Callable[[QuantumTape], QuantumTape]): A device preprocessing step
-            numpy_only (bool): whether or not to convert the data to numpy or leave as is
+            numpy_only (bool): whether to convert the data to numpy or leave as is
             device_execution (Callable[[Sequence[QuantumTape]], ResultBatch])
             cache (None | MutableMapping): The cache to use. If ``None``, caching will not occur.
         """
-        transform_program = qml.transforms.core.TransformProgram()
+
+        transform_program = qml.transforms.core.TransformProgram(inner_transform)

         if numpy_only:
             transform_program.add_transform(qml.transforms.convert_to_numpy_parameters)

         if cache is not None:
             transform_program.add_transform(_cache_transform, cache=cache)

+        transformed_tapes, transform_post_processing = transform_program(tapes)
+
         # TODO: Apply expand_fn() as transform.
         if expand_fn:
-            tapes = tuple(expand_fn(t) for t in tapes)
-
-        transformed_tapes, transform_post_processing = transform_program(tapes)
+            transformed_tapes = tuple(expand_fn(t) for t in transformed_tapes)

         if transformed_tapes:
             results = device_execution(transformed_tapes)
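
A sketch of the inner transform program this closure now assembles (the device and tape are illustrative): device preprocessing first, then numpy conversion and caching, all applied before any legacy ``expand_fn``.

```python
import pennylane as qml

dev = qml.devices.DefaultQubit()
config = qml.devices.ExecutionConfig()
inner_transform = dev.preprocess(config)[0]  # device preprocessing, now inside the ml boundary

program = qml.transforms.core.TransformProgram(inner_transform)
program.add_transform(qml.transforms.convert_to_numpy_parameters)  # the numpy_only branch
# caching would append _cache_transform here when a cache is provided

tapes = [qml.tape.QuantumScript([qml.RX(0.1, 0)], [qml.expval(qml.PauliZ(0))])]
transformed_tapes, post_processing = program(tapes)
```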
@@ -407,6 +414,7 @@ def execute(
     gradient_fn: Optional[Union[Callable, str]] = None,
     interface="auto",
     transform_program=None,
+    inner_transform=None,
     config=None,
     grad_on_execution="best",
     gradient_kwargs=None,
@@ -435,6 +443,7 @@
             This affects the types of parameters that can exist on the input tapes.
             Available options include ``autograd``, ``torch``, ``tf``, ``jax`` and ``auto``.
         transform_program(.TransformProgram): A transform program to be applied to the initial tape.
+        inner_transform (.TransformProgram): A transform program to be applied to the tapes in inner execution, inside the ml interface.
         config (qml.devices.ExecutionConfig): A datastructure describing the parameters needed to fully describe the execution.
         grad_on_execution (bool, str): Whether the gradients should be computed on the execution or not. Only applies
             if the device is queried for the gradient; gradient transform
@@ -587,11 +596,10 @@ def cost_fn(params, x):
         )
         config.mcm_config.postselect_mode = "fill-shots"

-    if transform_program is None:
-        if isinstance(device, qml.devices.Device):
-            transform_program = device.preprocess(config)[0]
-        else:
-            transform_program = qml.transforms.core.TransformProgram()
+    is_gradient_transform = isinstance(gradient_fn, qml.transforms.core.TransformDispatcher)
+    transform_program, inner_transform = _make_transform_programs(
+        device, config, inner_transform, transform_program, is_gradient_transform
+    )

     # If caching is desired but an explicit cache is not provided, use an ``LRUCache``.
     if cache is True:
@@ -617,6 +625,7 @@ def cost_fn(params, x):
         device,
         override_shots,
         cache,
+        inner_transform,
         expand_fn,
         config,
         numpy_only=not device_supports_interface_data,
@@ -754,7 +763,9 @@ def device_execute_and_gradients(internal_tapes, **gradient_kwargs):

     else:
         # need to override to have no cache
-        inner_execute = _make_inner_execute(device, override_shots, cache=None)
+        inner_execute = _make_inner_execute(
+            device, override_shots, cache=None, inner_transform=inner_transform
+        )

         def inner_execute_with_empty_jac(tapes, **_):
             return (inner_execute(tapes), [])
@@ -830,6 +841,35 @@ def device_gradient_fn(inner_tapes, **gradient_kwargs):
     return post_processing(results)


+def _make_transform_programs(
+    device, config, inner_transform, transform_program, is_gradient_transform
+):
+    """helper function to make the transform programs."""
+
+    if isinstance(device, qml.devices.Device):
+
+        # If gradient_fn is a gradient transform, device preprocessing should happen in
+        # inner execute (inside the ml boundary).
+        if is_gradient_transform:
+            if inner_transform is None:
+                inner_transform = device.preprocess(config)[0]
+            if transform_program is None:
+                transform_program = qml.transforms.core.TransformProgram()
+        else:
+            if inner_transform is None:
+                inner_transform = qml.transforms.core.TransformProgram()
+            if transform_program is None:
+                transform_program = device.preprocess(config)[0]
+
+    else:
+        if transform_program is None:
+            transform_program = qml.transforms.core.TransformProgram()
+        if inner_transform is None:
+            inner_transform = qml.transforms.core.TransformProgram()
+
+    return transform_program, inner_transform
+
+
 def _get_execution_config(
     gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config
 ):
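
To make the four branches concrete, a hedged sketch using the (private) helper directly; the exact transform counts depend on the device's preprocessing program:

```python
import pennylane as qml
from pennylane.workflow.execution import _make_transform_programs

dev = qml.devices.DefaultQubit()
config = qml.devices.ExecutionConfig()

# With a gradient transform, device preprocessing lands in the inner program:
outer, inner = _make_transform_programs(dev, config, None, None, True)
print(len(outer), len(inner))  # 0 and the number of device preprocessing transforms

# Without one, it stays in the outer program:
outer, inner = _make_transform_programs(dev, config, None, None, False)
print(len(outer), len(inner))  # the number of device preprocessing transforms and 0
```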
(Diffs for the remaining 6 changed files are not shown here.)
