Skip to content

Commit

Permalink
Fix map nested structure test.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Aug 11, 2024
1 parent 01381ab commit 4b3a65a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 55 deletions.
55 changes: 0 additions & 55 deletions keras/src/backend/tensorflow/nestedstructure_test.py

This file was deleted.

36 changes: 36 additions & 0 deletions keras/src/ops/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,42 @@ def f2(x):
self.assertAllClose(outputs["a"], xs**2)
self.assertAllClose(outputs["b"], xs * 10)

# Test with nested structures
def dict_input_fn(inputs):
x = inputs["x"][:, 0]
y = inputs["y"] + 1
return {"x": x, "y": y}

def list_input_fn(inputs):
return [x**2 for x in inputs]

xs = {
"x": ops.convert_to_tensor(
np.random.rand(4, 100, 3), dtype="float32"
),
"y": ops.convert_to_tensor(
np.random.randint(0, 10, size=(4, 1)), dtype="int32"
),
}
xs1 = [
ops.convert_to_tensor(np.random.rand(4, 100, 3), dtype="float32"),
ops.convert_to_tensor(
np.random.randint(0, 10, size=(4, 1)), dtype="int32"
),
]
ys = ops.map(dict_input_fn, xs)
self.assertEqual(ys["x"].shape, (4, 100))
self.assertEqual(
ops.convert_to_numpy(ys["y"]).all(),
ops.convert_to_numpy(xs["y"] + 1).all(),
)
ys = ops.map(list_input_fn, xs1)
for x, y in zip(xs1, ys):
self.assertEqual(
(ops.convert_to_numpy(y)).all(),
(ops.convert_to_numpy(x) ** 2).all(),
)

def test_scan(self):
# Test cumsum
def cumsum(carry, xs):
Expand Down

0 comments on commit 4b3a65a

Please sign in to comment.