From bea44ba8cdbd0a7766084c8aebbd9c5737c54daf Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 3 Jun 2024 17:04:28 -0400 Subject: [PATCH] Found a way to raise error with diff_method=None --- pennylane/workflow/execution.py | 62 ++++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index c87ffe2600d..a3f7dc845d9 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -368,6 +368,36 @@ def execution_function_with_caching(tapes): return execution_function_with_caching +def _get_interface_name(tapes, interface): + """Helper function to get the interface name of a list of tapes + + Args: + tapes (list[.QuantumScript]): Quantum tapes + interface (Optional[str]): Original interface to use as reference. + + Returns: + str: Interface name""" + if interface == "auto": + params = [] + for tape in tapes: + params.extend(tape.get_parameters(trainable_only=False)) + interface = qml.math.get_interface(*params) + if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph(): + interface = "tf-autograph" + if interface == "jax": + try: # pragma: no-cover + from .interfaces.jax import get_jax_interface_name + except ImportError as e: # pragma: no-cover + raise qml.QuantumFunctionError( # pragma: no-cover + "jax not found. Please install the latest " # pragma: no-cover + "version of jax to enable the 'jax' interface." # pragma: no-cover + ) from e # pragma: no-cover + + interface = get_jax_interface_name(tapes) + + return interface + + def execute( tapes: Sequence[QuantumTape], device: device_type, @@ -522,26 +552,10 @@ def cost_fn(params, x): ### Specifying and preprocessing variables #### - if interface == "auto": - params = [] - for tape in tapes: - params.extend(tape.get_parameters(trainable_only=False)) - interface = qml.math.get_interface(*params) - if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph(): - interface = "tf-autograph" - if interface == "jax": - try: # pragma: no-cover - from .interfaces.jax import get_jax_interface_name - except ImportError as e: # pragma: no-cover - raise qml.QuantumFunctionError( # pragma: no-cover - "jax not found. Please install the latest " # pragma: no-cover - "version of jax to enable the 'jax' interface." # pragma: no-cover - ) from e # pragma: no-cover - - interface = get_jax_interface_name(tapes) - # Only need to calculate derivatives with jax when we know it will be executed later. - if interface in {"jax", "jax-jit"}: - grad_on_execution = grad_on_execution if isinstance(gradient_fn, Callable) else False + interface = _get_interface_name(tapes, interface) + # Only need to calculate derivatives with jax when we know it will be executed later. + if interface in {"jax", "jax-jit"}: + grad_on_execution = grad_on_execution if isinstance(gradient_fn, Callable) else False if ( device_vjp @@ -558,7 +572,13 @@ def cost_fn(params, x): gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) - if interface == "jax-jit" and config.mcm_config.mcm_method == "deferred": + # Mid-circuit measurement configuration validation + if interface is None: + mcm_interface = "auto" + mcm_interface = _get_interface_name(tapes, mcm_interface) + else: + mcm_interface = interface + if mcm_interface == "jax-jit" and config.mcm_config.mcm_method == "deferred": # This is a current limitation of defer_measurements. "hw-like" behaviour is # not yet accessible. if config.mcm_config.postselect_mode == "hw-like":