diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 75bd2d872e9..b5aa5163048 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -219,6 +219,10 @@
[(#5256)](https://github.com/PennyLaneAI/pennylane/pull/5256)
[(#5395)](https://github.com/PennyLaneAI/pennylane/pull/5395)
+* A clear error message is added in `KerasLayer` when using the newest version of TensorFlow with Keras 3
+ (which is not currently compatible with `KerasLayer`), linking to instructions to enable Keras 2.
+ [(#5488)](https://github.com/PennyLaneAI/pennylane/pull/5488)
+
Breaking changes 💔
* The private functions `_pauli_mult`, `_binary_matrix` and `_get_pauli_map` from the `pauli` module have been removed. The same functionality can be achieved using newer features in the ``pauli`` module.
@@ -380,6 +384,7 @@ Astral Cai,
Isaac De Vlugt,
Amintor Dusko,
Pietropaolo Frisoni,
+Lillian M. A. Frederiksen,
Soran Jahangiri,
Korbinian Kottmann,
Christina Lee,
diff --git a/pennylane/qnn/keras.py b/pennylane/qnn/keras.py
index 6322d235094..1b533686454 100644
--- a/pennylane/qnn/keras.py
+++ b/pennylane/qnn/keras.py
@@ -16,12 +16,23 @@
import inspect
from collections.abc import Iterable
from typing import Optional, Text
+from semantic_version import Version
+
try:
import tensorflow as tf
from tensorflow.keras.layers import Layer
- CORRECT_TF_VERSION = int(tf.__version__.split(".", maxsplit=1)[0]) > 1
+ CORRECT_TF_VERSION = Version(tf.__version__) >= Version("2.0.0")
+ try:
+ # this feels a bit hacky, but if users *only* have an old (i.e. PL-compatible) version of Keras installed
+ # then tf.keras doesn't have a version attribute, and we *should be* good to go.
+ # if you have a newer version of Keras installed, then you can use tf.keras.version to check if you
+ # are configured to use Keras 3 or Keras 2
+ CORRECT_KERAS_VERSION = Version(tf.keras.version()) < Version("3.0.0")
+ except AttributeError:
+ CORRECT_KERAS_VERSION = True
+
except ImportError:
# The following allows this module to be imported even if TensorFlow is not installed. Users
# will instead see an ImportError when instantiating the KerasLayer.
@@ -40,6 +51,13 @@ class KerasLayer(Layer):
`Model `__ classes for
creating quantum and hybrid models.
+ .. note::
+
+ ``KerasLayer`` currently only supports Keras 2. If you are running the newest version
+ of TensorFlow and Keras, you may automatically be using Keras 3. For instructions
+ on running with Keras 2, instead, see the
+ `documentation on backwards compatibility `__ .
+
Args:
qnode (qml.QNode): the PennyLane QNode to be converted into a Keras Layer_
weight_shapes (dict[str, tuple]): a dictionary mapping from all weights used in the QNode to
@@ -299,6 +317,12 @@ def __init__(
"https://www.tensorflow.org/install for detailed instructions."
)
+ if not CORRECT_KERAS_VERSION:
+ raise ImportError(
+ "KerasLayer requires a Keras version lower than 3. For instructions on running with Keras 2,"
+ "visit https://keras.io/getting_started/#tensorflow--keras-2-backwards-compatibility."
+ )
+
self.weight_shapes = {
weight: (tuple(size) if isinstance(size, Iterable) else (size,) if size > 1 else ())
for weight, size in weight_shapes.items()
diff --git a/tests/qnn/test_keras.py b/tests/qnn/test_keras.py
index a75635ef18f..18d28dbc8ae 100644
--- a/tests/qnn/test_keras.py
+++ b/tests/qnn/test_keras.py
@@ -96,6 +96,7 @@ def indices_up_to_dm(n_max):
return zip(*[a + 1], zip(*[2 ** (b + 1), 2 ** (b + 1)]))
+# pylint: disable=too-many-public-methods
@pytest.mark.tf
@pytest.mark.parametrize("interface", ["tf"]) # required for the get_circuit fixture
@pytest.mark.usefixtures("get_circuit")
@@ -114,6 +115,20 @@ def test_bad_tf_version(
with pytest.raises(ImportError, match="KerasLayer requires TensorFlow version 2"):
KerasLayer(c, w, output_dim)
+ @pytest.mark.parametrize("n_qubits, output_dim", indices_up_to(1))
+ def test_bad_keras_version(
+ self, get_circuit, output_dim, monkeypatch
+ ): # pylint: disable=no-self-use
+ """Test if an ImportError is raised when instantiated with an incorrect version of
+ Keras."""
+ c, w = get_circuit
+ with monkeypatch.context() as m:
+ m.setattr(qml.qnn.keras, "CORRECT_KERAS_VERSION", False)
+ with pytest.raises(
+ ImportError, match="KerasLayer requires a Keras version lower than 3"
+ ):
+ KerasLayer(c, w, output_dim)
+
@pytest.mark.parametrize("n_qubits, output_dim", indices_up_to(1))
def test_no_input(self): # pylint: disable=no-self-use
"""Test if a TypeError is raised when instantiated with a QNode that does not have an