diff --git a/pennylane/workflow/construct_batch.py b/pennylane/workflow/construct_batch.py index b57c4f15dcb..c12a07427ec 100644 --- a/pennylane/workflow/construct_batch.py +++ b/pennylane/workflow/construct_batch.py @@ -18,8 +18,6 @@ from functools import wraps from typing import Callable, Literal, Optional, Tuple, Union -import tensorflow as tf - import pennylane as qml from .qnode import QNode, _get_device_shots, _make_execution_config @@ -299,6 +297,9 @@ def batch_constructor(*args, **kwargs) -> Tuple[Tuple["qml.tape.QuantumTape", Ca shots = kwargs.pop("shots", _get_device_shots(qnode.device)) if isinstance(qnode, qml.qnn.KerasLayer): + # pylint: disable=import-outside-toplevel + import tensorflow as tf + with tf.GradientTape() as tape: tape.watch(list(qnode.qnode_weights.values()))