Skip to content

Commit

Permalink
Update TF version check to check Keras (#5488)
Browse files Browse the repository at this point in the history
**Context:**
If you `pip install tensorflow` (or update) following their most recent
Tensorflow release, you get Keras 3 instead of Keras 2. Currently the
`KerasLayer` in the `qnn` module isn't compatible with Keras 3.

**Description of the Change:**
We add a check to make sure that users have the correct version of
Keras. If we find they are using Keras 3, the error message directs to
the official documentation for running Keras 2 instead, see
[here](https://keras.io/getting_started/#tensorflow--keras-2-backwards-compatibility).

**Testing:**
The unit-testing for this behaviour isn't great, because it just mocks
the result of checking the versions (i.e. changes the value of
`CORRECT_TF_VERSION` and `CORRECT_KERAS_VERSION`) and checks that the
error is raised, rather than checking that the version checks themselves
work as expected. The behaviour occurs on import, so a workaround to
test the behaviour isn't, to my knowledge, straightforward.

I've tested locally by modifying my testing environment, and confirmed
that:

1. With only `Keras 2` installed, the code runs
2. With only `Keras 3` installed, an error directs to the Keras official
documentation on using Keras 2 instead
3. With `tf_keras` and `Keras 3`, but without the correct variable set
to use legacy Keras, an error directs to the Keras official
documentation on using `Keras 2` instead (which includes info about
setting the correct global variable)
4. With `tf_keras` and `Keras 3`, and the variable to use legacy Keras
set as instructed in the docs, the code runs

**Benefits:**
People won't get weird errors if they install/upgrade TensorFlow and try
to use the QNN module. Instead, a clear error will tell them that the
problem is their environment setup and how to fix it.

**Possible Drawbacks:**
This is short-time solution and doesn't fix the underlying problem that
we aren't currently compatible with Keras 3. That is tracked in a story
[here](https://app.shortcut.com/xanaduai/story/59479/update-keraslayer-to-support-keras3)
for later.
  • Loading branch information
lillian542 committed Apr 11, 2024
1 parent 6429086 commit 9e0781e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@
[(#5256)](https://github.com/PennyLaneAI/pennylane/pull/5256)
[(#5395)](https://github.com/PennyLaneAI/pennylane/pull/5395)

* A clear error message is added in `KerasLayer` when using the newest version of TensorFlow with Keras 3
(which is not currently compatible with `KerasLayer`), linking to instructions to enable Keras 2.
[(#5488)](https://github.com/PennyLaneAI/pennylane/pull/5488)

<h3>Breaking changes 💔</h3>

* The private functions `_pauli_mult`, `_binary_matrix` and `_get_pauli_map` from the `pauli` module have been removed. The same functionality can be achieved using newer features in the ``pauli`` module.
Expand Down Expand Up @@ -380,6 +384,7 @@ Astral Cai,
Isaac De Vlugt,
Amintor Dusko,
Pietropaolo Frisoni,
Lillian M. A. Frederiksen,
Soran Jahangiri,
Korbinian Kottmann,
Christina Lee,
Expand Down
26 changes: 25 additions & 1 deletion pennylane/qnn/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,23 @@
import inspect
from collections.abc import Iterable
from typing import Optional, Text
from semantic_version import Version


try:
import tensorflow as tf
from tensorflow.keras.layers import Layer

CORRECT_TF_VERSION = int(tf.__version__.split(".", maxsplit=1)[0]) > 1
CORRECT_TF_VERSION = Version(tf.__version__) >= Version("2.0.0")
try:
# this feels a bit hacky, but if users *only* have an old (i.e. PL-compatible) version of Keras installed
# then tf.keras doesn't have a version attribute, and we *should be* good to go.
# if you have a newer version of Keras installed, then you can use tf.keras.version to check if you
# are configured to use Keras 3 or Keras 2
CORRECT_KERAS_VERSION = Version(tf.keras.version()) < Version("3.0.0")
except AttributeError:
CORRECT_KERAS_VERSION = True

except ImportError:
# The following allows this module to be imported even if TensorFlow is not installed. Users
# will instead see an ImportError when instantiating the KerasLayer.
Expand All @@ -40,6 +51,13 @@ class KerasLayer(Layer):
`Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ classes for
creating quantum and hybrid models.
.. note::
``KerasLayer`` currently only supports Keras 2. If you are running the newest version
of TensorFlow and Keras, you may automatically be using Keras 3. For instructions
on running with Keras 2, instead, see the
`documentation on backwards compatibility <https://keras.io/getting_started/#tensorflow--keras-2-backwards-compatibility>`__ .
Args:
qnode (qml.QNode): the PennyLane QNode to be converted into a Keras Layer_
weight_shapes (dict[str, tuple]): a dictionary mapping from all weights used in the QNode to
Expand Down Expand Up @@ -299,6 +317,12 @@ def __init__(
"https://www.tensorflow.org/install for detailed instructions."
)

if not CORRECT_KERAS_VERSION:
raise ImportError(
"KerasLayer requires a Keras version lower than 3. For instructions on running with Keras 2,"
"visit https://keras.io/getting_started/#tensorflow--keras-2-backwards-compatibility."
)

self.weight_shapes = {
weight: (tuple(size) if isinstance(size, Iterable) else (size,) if size > 1 else ())
for weight, size in weight_shapes.items()
Expand Down
15 changes: 15 additions & 0 deletions tests/qnn/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def indices_up_to_dm(n_max):
return zip(*[a + 1], zip(*[2 ** (b + 1), 2 ** (b + 1)]))


# pylint: disable=too-many-public-methods
@pytest.mark.tf
@pytest.mark.parametrize("interface", ["tf"]) # required for the get_circuit fixture
@pytest.mark.usefixtures("get_circuit")
Expand All @@ -114,6 +115,20 @@ def test_bad_tf_version(
with pytest.raises(ImportError, match="KerasLayer requires TensorFlow version 2"):
KerasLayer(c, w, output_dim)

@pytest.mark.parametrize("n_qubits, output_dim", indices_up_to(1))
def test_bad_keras_version(
self, get_circuit, output_dim, monkeypatch
): # pylint: disable=no-self-use
"""Test if an ImportError is raised when instantiated with an incorrect version of
Keras."""
c, w = get_circuit
with monkeypatch.context() as m:
m.setattr(qml.qnn.keras, "CORRECT_KERAS_VERSION", False)
with pytest.raises(
ImportError, match="KerasLayer requires a Keras version lower than 3"
):
KerasLayer(c, w, output_dim)

@pytest.mark.parametrize("n_qubits, output_dim", indices_up_to(1))
def test_no_input(self): # pylint: disable=no-self-use
"""Test if a TypeError is raised when instantiated with a QNode that does not have an
Expand Down

0 comments on commit 9e0781e

Please sign in to comment.