From 3d9143ae4f3ed9a3b88f89fbf42086b8e3a982fb Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Wed, 19 Jun 2024 11:31:54 +0100 Subject: [PATCH] Catch another tf.keras usage --- setup.py | 4 ++-- tests/gpflux/layers/test_trackable_layer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index ad04215a..84d46e1c 100644 --- a/setup.py +++ b/setup.py @@ -9,9 +9,9 @@ "gpflow>=2.9.2", "numpy<2", "scipy", - "tensorflow>=2.5.0,<=2.16.1; platform_system!='Darwin' or platform_machine!='arm64'", + "tensorflow>=2.5.0,<2.17; platform_system!='Darwin' or platform_machine!='arm64'", # NOTE: Support of Apple Silicon MacOS platforms is in an experimental mode - "tensorflow-macos>=2.5.0,<=2.16.1; platform_system=='Darwin' and platform_machine=='arm64'", + "tensorflow-macos>=2.5.0,<2.17; platform_system=='Darwin' and platform_machine=='arm64'", "tensorflow-probability>=0.13.0,<=0.24", ] diff --git a/tests/gpflux/layers/test_trackable_layer.py b/tests/gpflux/layers/test_trackable_layer.py index 80174ff6..c78bf0b4 100644 --- a/tests/gpflux/layers/test_trackable_layer.py +++ b/tests/gpflux/layers/test_trackable_layer.py @@ -18,9 +18,9 @@ import pytest import tensorflow as tf from packaging.version import Version -from tensorflow.keras.layers import Layer from gpflow import default_float +from gpflow.keras import tf_keras from gpflow.kernels import RBF, Matern12, Matern52 import gpflux @@ -34,7 +34,7 @@ def __init__(self, attributes): super().__init__() -class UntrackableCompositeLayer(Layer): +class UntrackableCompositeLayer(tf_keras.Layer): def __init__(self, attributes): for i, a in enumerate(attributes): setattr(self, f"var_{i}", a)