Skip to content

Commit

Permalink
Found a way to raise error with diff_method=None
Browse files Browse the repository at this point in the history
  • Loading branch information
mudit2812 committed Jun 3, 2024
1 parent e0555e5 commit bea44ba
Showing 1 changed file with 41 additions and 21 deletions.
62 changes: 41 additions & 21 deletions pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,36 @@ def execution_function_with_caching(tapes):
return execution_function_with_caching


def _get_interface_name(tapes, interface):
"""Helper function to get the interface name of a list of tapes
Args:
tapes (list[.QuantumScript]): Quantum tapes
interface (Optional[str]): Original interface to use as reference.
Returns:
str: Interface name"""
if interface == "auto":
params = []
for tape in tapes:
params.extend(tape.get_parameters(trainable_only=False))
interface = qml.math.get_interface(*params)
if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph():
interface = "tf-autograph"
if interface == "jax":
try: # pragma: no-cover
from .interfaces.jax import get_jax_interface_name
except ImportError as e: # pragma: no-cover
raise qml.QuantumFunctionError( # pragma: no-cover
"jax not found. Please install the latest " # pragma: no-cover
"version of jax to enable the 'jax' interface." # pragma: no-cover
) from e # pragma: no-cover

interface = get_jax_interface_name(tapes)

return interface


def execute(
tapes: Sequence[QuantumTape],
device: device_type,
Expand Down Expand Up @@ -522,26 +552,10 @@ def cost_fn(params, x):

### Specifying and preprocessing variables ####

if interface == "auto":
params = []
for tape in tapes:
params.extend(tape.get_parameters(trainable_only=False))
interface = qml.math.get_interface(*params)
if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph():
interface = "tf-autograph"
if interface == "jax":
try: # pragma: no-cover
from .interfaces.jax import get_jax_interface_name
except ImportError as e: # pragma: no-cover
raise qml.QuantumFunctionError( # pragma: no-cover
"jax not found. Please install the latest " # pragma: no-cover
"version of jax to enable the 'jax' interface." # pragma: no-cover
) from e # pragma: no-cover

interface = get_jax_interface_name(tapes)
# Only need to calculate derivatives with jax when we know it will be executed later.
if interface in {"jax", "jax-jit"}:
grad_on_execution = grad_on_execution if isinstance(gradient_fn, Callable) else False
interface = _get_interface_name(tapes, interface)
# Only need to calculate derivatives with jax when we know it will be executed later.
if interface in {"jax", "jax-jit"}:
grad_on_execution = grad_on_execution if isinstance(gradient_fn, Callable) else False

if (
device_vjp
Expand All @@ -558,7 +572,13 @@ def cost_fn(params, x):
gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config
)

if interface == "jax-jit" and config.mcm_config.mcm_method == "deferred":
# Mid-circuit measurement configuration validation
if interface is None:
mcm_interface = "auto"
mcm_interface = _get_interface_name(tapes, mcm_interface)
else:
mcm_interface = interface
if mcm_interface == "jax-jit" and config.mcm_config.mcm_method == "deferred":
# This is a current limitation of defer_measurements. "hw-like" behaviour is
# not yet accessible.
if config.mcm_config.postselect_mode == "hw-like":
Expand Down

0 comments on commit bea44ba

Please sign in to comment.