Skip to content

Commit

Permalink
Clean up how interface is handled in QNode and qml.execute (#6225)
Browse files Browse the repository at this point in the history
Regarding `numpy` and `autograd`:
- When the parameters are of the `numpy` interface, internally treat it
as `interface=None`.
- Does not change the behaviour of treating user specified
`interface="numpy"` as using autograd.

Regarding interfaces in general:
- The set of canonical interface names in `INTERFACE_MAP` is expanded to
include more specific names such as `jax-jit`, and `tf-autograph`.
`_convert_to_interfaces` in `qnode.py` uses a separate
`interface_conversion_map` to further map the specific interfaces to
their corresponding general interface names that can be passed to the
`like` argument of `qml.math.asarray` (e.g. "tf" to "tensorflow",
"jax-jit" to "jax").
- In `QNode` and `qml.execute`, every time we get an interface from user
input or `qml.math.get_interface`, we map it to a canonical interface
name using `INTERFACE_MAP`. Aside from these two scenarios, we assume
that the interface name is one of the canonical interface names
everywhere else. `QNode.interface` is now assumed to be one of the
canonical interface names.
- User input of `interface=None` gets mapped to `numpy` immediately.
Internally, `QNode.interface` will never be `None`. It'll be `numpy` for
having no interface.
- If `qml.math.get_interface` returns `numpy`, we do not map it to
anything. We keep `numpy`.

Collateral bug fix included as well:
- Fixes a bug where a circuit of the `autograd` interfaces sometimes
returns results that are not `autograd`.
- Adds `compute_sparse_matrix` to `Hermitian`

[sc-73144]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
astralcai and albi3ro committed Sep 16, 2024
1 parent 94f067a commit 228fdaf
Show file tree
Hide file tree
Showing 19 changed files with 221 additions and 144 deletions.
19 changes: 11 additions & 8 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,14 @@
[(#6061)](https://github.com/PennyLaneAI/pennylane/pull/6061)

* `qml.qchem.excitations` now optionally returns fermionic operators.
[(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171)
[(#6171)](https://github.com/PennyLaneAI/pennylane/pull/6171)

* The `diagonalize_measurements` transform now uses a more efficient method of diagonalization
when possible, based on the `pauli_rep` of the relevant observables.
[#6113](https://github.com/PennyLaneAI/pennylane/pull/6113/)

<h4>Capturing and representing hybrid programs</h4>

* Differentiation of hybrid programs via `qml.grad` can now be captured into plxpr.
When evaluating a captured `qml.grad` instruction, it will dispatch to `jax.grad`,
which differs from the Autograd implementation of `qml.grad` itself.
[(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120)
* The `Hermitian` operator now has a `compute_sparse_matrix` implementation.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)

<h4>Capturing and representing hybrid programs</h4>

Expand Down Expand Up @@ -120,12 +116,19 @@
* The ``qml.FABLE`` template now returns the correct value when JIT is enabled.
[(#6263)](https://github.com/PennyLaneAI/pennylane/pull/6263)

* <h3>Contributors ✍️</h3>
* Fixes a bug where a circuit using the `autograd` interface sometimes returns nested values that are not of the `autograd` interface.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)

* Fixes a bug where a simple circuit with no parameters or only builtin/numpy arrays as parameters returns autograd tensors.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Guillermo Alonso,
Utkarsh Azad,
Astral Cai,
Lillian M. A. Frederiksen,
Pietropaolo Frisoni,
Emiliano Godinez,
Expand Down
6 changes: 3 additions & 3 deletions pennylane/devices/execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Optional, Union

from pennylane.workflow import SUPPORTED_INTERFACES
from pennylane.workflow import SUPPORTED_INTERFACE_NAMES


@dataclass
Expand Down Expand Up @@ -110,9 +110,9 @@ def __post_init__(self):
Note that this hook is automatically called after init via the dataclass integration.
"""
if self.interface not in SUPPORTED_INTERFACES:
if self.interface not in SUPPORTED_INTERFACE_NAMES:
raise ValueError(
f"Unknown interface. interface must be in {SUPPORTED_INTERFACES}, got {self.interface} instead."
f"Unknown interface. interface must be in {SUPPORTED_INTERFACE_NAMES}, got {self.interface} instead."
)

if self.grad_on_execution not in {True, False, None}:
Expand Down
10 changes: 5 additions & 5 deletions pennylane/devices/legacy_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pennylane as qml
from pennylane.measurements import MidMeasureMP, Shots
from pennylane.transforms.core.transform_program import TransformProgram
from pennylane.workflow.execution import INTERFACE_MAP

from .device_api import Device
from .execution_config import DefaultExecutionConfig
Expand Down Expand Up @@ -322,25 +323,24 @@ def _validate_backprop_method(self, tape):
return False
params = tape.get_parameters(trainable_only=False)
interface = qml.math.get_interface(*params)
if interface != "numpy":
interface = INTERFACE_MAP.get(interface, interface)

if tape and any(isinstance(m.obs, qml.SparseHamiltonian) for m in tape.measurements):
return False
if interface == "numpy":
interface = None
mapped_interface = qml.workflow.execution.INTERFACE_MAP.get(interface, interface)

# determine if the device supports backpropagation
backprop_interface = self._device.capabilities().get("passthru_interface", None)

if backprop_interface is not None:
# device supports backpropagation natively
return mapped_interface in [backprop_interface, "Numpy"]
return interface in [backprop_interface, "numpy"]
# determine if the device has any child devices that support backpropagation
backprop_devices = self._device.capabilities().get("passthru_devices", None)

if backprop_devices is None:
return False
return mapped_interface in backprop_devices or mapped_interface == "Numpy"
return interface in backprop_devices or interface == "numpy"

def _validate_adjoint_method(self, tape):
# The conditions below provide a minimal set of requirements that we can likely improve upon in
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ def _(original_measurement: ExpectationMP, measures): # pylint: disable=unused-
for v in measures.values():
if not v[0] or v[1] is tuple():
continue
cum_value += v[0] * v[1]
cum_value += qml.math.multiply(v[0], v[1])
total_counts += v[0]
return cum_value / total_counts

Expand All @@ -935,7 +935,7 @@ def _(original_measurement: ProbabilityMP, measures): # pylint: disable=unused-
for v in measures.values():
if not v[0] or v[1] is tuple():
continue
cum_value += v[0] * v[1]
cum_value += qml.math.multiply(v[0], v[1])
total_counts += v[0]
return cum_value / total_counts

Expand Down
4 changes: 4 additions & 0 deletions pennylane/ops/qubit/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ def compute_matrix(A: TensorLike) -> TensorLike: # pylint: disable=arguments-di
Hermitian._validate_input(A)
return A

@staticmethod
def compute_sparse_matrix(A) -> csr_matrix: # pylint: disable=arguments-differ
return csr_matrix(Hermitian.compute_matrix(A))

@property
def eigendecomposition(self) -> dict[str, TensorLike]:
"""Return the eigendecomposition of the matrix specified by the Hermitian observable.
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@
"""
from .construct_batch import construct_batch, get_transform_program
from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES, execute
from .execution import INTERFACE_MAP, SUPPORTED_INTERFACE_NAMES, execute
from .qnode import QNode, qnode
from .set_shots import set_shots
75 changes: 43 additions & 32 deletions pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,9 @@
"autograd",
"numpy",
"torch",
"pytorch",
"jax",
"jax-python",
"jax-jit",
"tf",
"tensorflow",
}

SupportedInterfaceUserInput = Literal[
Expand All @@ -78,30 +75,29 @@
]

_mapping_output = (
"Numpy",
"numpy",
"auto",
"autograd",
"autograd",
"numpy",
"jax",
"jax",
"jax-jit",
"jax",
"jax",
"torch",
"torch",
"tf",
"tf",
"tf",
"tf",
"tf-autograph",
"tf-autograph",
)

INTERFACE_MAP = dict(zip(get_args(SupportedInterfaceUserInput), _mapping_output))
"""dict[str, str]: maps an allowed interface specification to its canonical name."""

#: list[str]: allowed interface strings
SUPPORTED_INTERFACES = list(INTERFACE_MAP)
SUPPORTED_INTERFACE_NAMES = list(INTERFACE_MAP)
"""list[str]: allowed interface strings"""


_CACHED_EXECUTION_WITH_FINITE_SHOTS_WARNINGS = (
"Cached execution with finite shots detected!\n"
"Note that samples as well as all noisy quantities computed via sampling "
Expand Down Expand Up @@ -135,41 +131,41 @@ def _get_ml_boundary_execute(
pennylane.QuantumFunctionError if the required package is not installed.
"""
mapped_interface = INTERFACE_MAP[interface]
try:
if mapped_interface == "autograd":
if interface == "autograd":
from .interfaces.autograd import autograd_execute as ml_boundary

elif mapped_interface == "tf":
if "autograph" in interface:
from .interfaces.tensorflow_autograph import execute as ml_boundary
elif interface == "tf-autograph":
from .interfaces.tensorflow_autograph import execute as ml_boundary

ml_boundary = partial(ml_boundary, grad_on_execution=grad_on_execution)
ml_boundary = partial(ml_boundary, grad_on_execution=grad_on_execution)

else:
from .interfaces.tensorflow import tf_execute as full_ml_boundary
elif interface == "tf":
from .interfaces.tensorflow import tf_execute as full_ml_boundary

ml_boundary = partial(full_ml_boundary, differentiable=differentiable)
ml_boundary = partial(full_ml_boundary, differentiable=differentiable)

elif mapped_interface == "torch":
elif interface == "torch":
from .interfaces.torch import execute as ml_boundary

elif interface == "jax-jit":
if device_vjp:
from .interfaces.jax_jit import jax_jit_vjp_execute as ml_boundary
else:
from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary
else: # interface in {"jax", "jax-python", "JAX"}:

else: # interface is jax
if device_vjp:
from .interfaces.jax_jit import jax_jit_vjp_execute as ml_boundary
else:
from .interfaces.jax import jax_jvp_execute as ml_boundary

except ImportError as e: # pragma: no cover
raise qml.QuantumFunctionError(
f"{mapped_interface} not found. Please install the latest "
f"version of {mapped_interface} to enable the '{mapped_interface}' interface."
f"{interface} not found. Please install the latest "
f"version of {interface} to enable the '{interface}' interface."
) from e

return ml_boundary


Expand Down Expand Up @@ -263,12 +259,22 @@ def _get_interface_name(tapes, interface):
Returns:
str: Interface name"""

if interface not in SUPPORTED_INTERFACE_NAMES:
raise qml.QuantumFunctionError(
f"Unknown interface {interface}. Interface must be one of {SUPPORTED_INTERFACE_NAMES}."
)

interface = INTERFACE_MAP[interface]

if interface == "auto":
params = []
for tape in tapes:
params.extend(tape.get_parameters(trainable_only=False))
interface = qml.math.get_interface(*params)
if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph():
if interface != "numpy":
interface = INTERFACE_MAP[interface]
if interface == "tf" and _use_tensorflow_autograph():
interface = "tf-autograph"
if interface == "jax":
try: # pragma: no cover
Expand Down Expand Up @@ -439,6 +445,7 @@ def cost_fn(params, x):

### Specifying and preprocessing variables ####

_interface_user_input = interface
interface = _get_interface_name(tapes, interface)
# Only need to calculate derivatives with jax when we know it will be executed later.
if interface in {"jax", "jax-jit"}:
Expand All @@ -460,7 +467,11 @@ def cost_fn(params, x):
)

# Mid-circuit measurement configuration validation
mcm_interface = interface or _get_interface_name(tapes, "auto")
# If the user specifies `interface=None`, regular execution considers it numpy, but the mcm
# workflow still needs to know if jax-jit is used
mcm_interface = (
_get_interface_name(tapes, "auto") if _interface_user_input is None else interface
)
finite_shots = any(tape.shots for tape in tapes)
_update_mcm_config(config.mcm_config, mcm_interface, finite_shots)

Expand All @@ -479,12 +490,12 @@ def cost_fn(params, x):
cache = None

# changing this set of conditions causes a bunch of tests to break.
no_interface_boundary_required = interface is None or config.gradient_method in {
no_interface_boundary_required = interface == "numpy" or config.gradient_method in {
None,
"backprop",
}
device_supports_interface_data = no_interface_boundary_required and (
interface is None
interface == "numpy"
or config.gradient_method == "backprop"
or getattr(device, "short_name", "") == "default.mixed"
)
Expand All @@ -497,9 +508,9 @@ def cost_fn(params, x):
numpy_only=not device_supports_interface_data,
)

# moved to its own explicit step so it will be easier to remove
# moved to its own explicit step so that it will be easier to remove
def inner_execute_with_empty_jac(tapes, **_):
return (inner_execute(tapes), [])
return inner_execute(tapes), []

if interface in jpc_interfaces:
execute_fn = inner_execute
Expand All @@ -522,7 +533,7 @@ def inner_execute_with_empty_jac(tapes, **_):
and getattr(device, "short_name", "") in ("lightning.gpu", "lightning.kokkos")
and interface in jpc_interfaces
): # pragma: no cover
if INTERFACE_MAP[interface] == "jax" and "use_device_state" in gradient_kwargs:
if "jax" in interface and "use_device_state" in gradient_kwargs:
gradient_kwargs["use_device_state"] = False

jpc = LightningVJPs(device, gradient_kwargs=gradient_kwargs)
Expand Down Expand Up @@ -563,7 +574,7 @@ def execute_fn(internal_tapes) -> tuple[ResultBatch, tuple]:
config: the ExecutionConfig that specifies how to perform the simulations.
"""
numpy_tapes, _ = qml.transforms.convert_to_numpy_parameters(internal_tapes)
return (device.execute(numpy_tapes, config), tuple())
return device.execute(numpy_tapes, config), tuple()

def gradient_fn(internal_tapes):
"""A partial function that wraps compute_derivatives method of the device.
Expand Down Expand Up @@ -612,7 +623,7 @@ def gradient_fn(internal_tapes):

# trainable parameters can only be set on the first pass for jax
# not higher order passes for higher order derivatives
if interface in {"jax", "jax-python", "jax-jit"}:
if "jax" in interface:
for tape in tapes:
params = tape.get_parameters(trainable_only=False)
tape.trainable_params = qml.math.get_trainable_indices(params)
Expand Down
17 changes: 16 additions & 1 deletion pennylane/workflow/interfaces/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,21 @@ def autograd_execute(
return _execute(parameters, tuple(tapes), execute_fn, jpc)


def _to_autograd(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:
"""Converts an arbitrary result batch to one with autograd arrays.
Args:
result (ResultBatch): a nested structure of lists, tuples, dicts, and numpy arrays
Returns:
ResultBatch: a nested structure of tuples, dicts, and jax arrays
"""
if isinstance(result, dict):
return result
# pylint: disable=no-member
if isinstance(result, (list, tuple, autograd.builtins.tuple, autograd.builtins.list)):
return tuple(_to_autograd(r) for r in result)
return autograd.numpy.array(result)


@autograd.extend.primitive
def _execute(
parameters,
Expand All @@ -165,7 +180,7 @@ def _execute(
for the input tapes.
"""
return execute_fn(tapes)
return _to_autograd(execute_fn(tapes))


# pylint: disable=unused-argument
Expand Down
Loading

0 comments on commit 228fdaf

Please sign in to comment.