Skip to content

Commit

Permalink
localize tensorflow import in construct_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
Shiro-Raven committed May 31, 2024
1 parent c7abcf2 commit 14696c8
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pennylane/workflow/construct_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from functools import wraps
from typing import Callable, Literal, Optional, Tuple, Union

import tensorflow as tf

import pennylane as qml

from .qnode import QNode, _get_device_shots, _make_execution_config
Expand Down Expand Up @@ -299,6 +297,9 @@ def batch_constructor(*args, **kwargs) -> Tuple[Tuple["qml.tape.QuantumTape", Ca
shots = kwargs.pop("shots", _get_device_shots(qnode.device))

if isinstance(qnode, qml.qnn.KerasLayer):
# pylint: disable=import-outside-toplevel
import tensorflow as tf

with tf.GradientTape() as tape:
tape.watch(list(qnode.qnode_weights.values()))

Expand Down

0 comments on commit 14696c8

Please sign in to comment.