From 4b3a65a63e06d754f38645f667ecb0e89251908a Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sun, 11 Aug 2024 15:28:59 -0700 Subject: [PATCH] Fix map nested structure test. --- .../tensorflow/nestedstructure_test.py | 55 ------------------- keras/src/ops/core_test.py | 36 ++++++++++++ 2 files changed, 36 insertions(+), 55 deletions(-) delete mode 100644 keras/src/backend/tensorflow/nestedstructure_test.py diff --git a/keras/src/backend/tensorflow/nestedstructure_test.py b/keras/src/backend/tensorflow/nestedstructure_test.py deleted file mode 100644 index 244e94f6b382..000000000000 --- a/keras/src/backend/tensorflow/nestedstructure_test.py +++ /dev/null @@ -1,55 +0,0 @@ -import numpy as np -import pytest -import tensorflow as tf - -import keras -from keras.src import backend -from keras.src import testing - - -@pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="The nestedstructure test can only run with TF backend.", -) -def dict_input_fn(inputs): - x = inputs["x"][:, 0] - y = inputs["y"] + 1 - return {"x": x, "y": y} - - -def list_input_fn(inputs): - return [x**2 for x in inputs] - - -class NestedTest(testing.TestCase): - def setUp(self): - self.xs = { - "x": tf.convert_to_tensor( - np.random.rand(4, 100, 3), dtype=tf.float32 - ), - "y": tf.convert_to_tensor( - np.random.randint(0, 10, size=(4, 1)), dtype=tf.int32 - ), - } - self.xs1 = [ - tf.convert_to_tensor(np.random.rand(4, 100, 3), dtype=tf.float32), - tf.convert_to_tensor( - np.random.randint(0, 10, size=(4, 1)), dtype=tf.int32 - ), - ] - - def test_dict_input_fn_outputs(self): - ys = keras.ops.map(dict_input_fn, self.xs) - self.assertEqual(ys["x"].shape, (4, 100)) - self.assertEqual( - keras.ops.convert_to_numpy(ys["y"]).all(), - keras.ops.convert_to_numpy(self.xs["y"] + 1).all(), - ) - - def test_list_input_fn_outputs(self): - ys = keras.ops.map(list_input_fn, self.xs1) - for i, (x, y) in enumerate(zip(self.xs1, ys)): - self.assertEqual( - (keras.ops.convert_to_numpy(y)).all(), - (keras.ops.convert_to_numpy(x) ** 2).all(), - ) diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index c02f75e614cf..f0dc56150e64 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -161,6 +161,42 @@ def f2(x): self.assertAllClose(outputs["a"], xs**2) self.assertAllClose(outputs["b"], xs * 10) + # Test with nested structures + def dict_input_fn(inputs): + x = inputs["x"][:, 0] + y = inputs["y"] + 1 + return {"x": x, "y": y} + + def list_input_fn(inputs): + return [x**2 for x in inputs] + + xs = { + "x": ops.convert_to_tensor( + np.random.rand(4, 100, 3), dtype="float32" + ), + "y": ops.convert_to_tensor( + np.random.randint(0, 10, size=(4, 1)), dtype="int32" + ), + } + xs1 = [ + ops.convert_to_tensor(np.random.rand(4, 100, 3), dtype="float32"), + ops.convert_to_tensor( + np.random.randint(0, 10, size=(4, 1)), dtype="int32" + ), + ] + ys = ops.map(dict_input_fn, xs) + self.assertEqual(ys["x"].shape, (4, 100)) + self.assertEqual( + ops.convert_to_numpy(ys["y"]).all(), + ops.convert_to_numpy(xs["y"] + 1).all(), + ) + ys = ops.map(list_input_fn, xs1) + for x, y in zip(xs1, ys): + self.assertEqual( + (ops.convert_to_numpy(y)).all(), + (ops.convert_to_numpy(x) ** 2).all(), + ) + def test_scan(self): # Test cumsum def cumsum(carry, xs):