Skip to content

Commit

Permalink
Fixing issue keras-team#20021: Unexpected result on keras.ops.nn.soft…
Browse files Browse the repository at this point in the history
…max when running with tuple axis (keras-team#20022)

* Fixing issue keras-team#20021: Unexpected result on keras.ops.nn.softmax when running with tuple axis

* Fixing issue keras-team#20021: Unexpected result on keras.ops.nn.softmax when running with tuple axis

* Fixing issue keras-team#20021: Unexpected result on keras.ops.nn.softmax when running with tuple axis

* Fixing issue keras-team#20021: Unexpected result on keras.ops.nn.softmax when running with tuple axis

* Fixing issue keras-team#20021: Unexpected result on keras.ops.nn.softmax when running with tuple axis

* Fixing issue keras-team#20021: Unexpected result on keras.ops.nn.softmax when running with tuple axis

* Fixing issue keras-team#20021: Unexpected result on keras.ops.nn.softmax when running with tuple axis

* Fixing issue keras-team#20021: Unexpected result on keras.ops.nn.softmax when running with tuple axis

* Fixing issue keras-team#20021: Unexpected result on keras.ops.nn.softmax when running with tuple axis
  • Loading branch information
shashaka authored Jul 22, 2024
1 parent 36a0628 commit 93786e3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
30 changes: 13 additions & 17 deletions keras/src/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,23 +552,19 @@ def softmax(x, axis=-1):
if any_symbolic_tensors((x,)):
return Softmax(axis).symbolic_call(x)
if isinstance(axis, tuple):
original_shape = x.shape
new_shape = []
skip_dims = set(axis)
i = 0
while i < len(original_shape):
if i in skip_dims:
size = 1
while i in skip_dims:
size *= original_shape[i]
i += 1
new_shape.append(size)
else:
new_shape.append(original_shape[i])
i += 1
x = backend.numpy.reshape(x, new_shape)
x = backend.nn.softmax(x, axis=-1)
x = backend.numpy.reshape(x, original_shape)
axis_to_keep = [v for v in range(len(x.shape)) if v not in axis]

x_transposed = backend.numpy.transpose(x, axes=(*axis_to_keep, *axis))
x_reshaped = backend.numpy.reshape(
x_transposed, (*[x.shape[v] for v in axis_to_keep], -1)
)

x = backend.nn.softmax(x_reshaped, axis=-1)

x = backend.numpy.reshape(x, x_transposed.shape)
x = backend.numpy.transpose(
x, axes=list(backend.numpy.argsort([*axis_to_keep, *axis]))
)
return x
else:
return backend.nn.softmax(x, axis=axis)
Expand Down
10 changes: 10 additions & 0 deletions keras/src/ops/nn_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from itertools import combinations

import numpy as np
import pytest
from absl.testing import parameterized
Expand Down Expand Up @@ -1248,6 +1250,14 @@ def test_softmax(self):
],
)

def test_softmax_correctness_with_axis_tuple(self):
    """Softmax over a tuple of axes must sum to 1 across those axes.

    Regression test for keras-team#20021, where `keras.ops.nn.softmax`
    produced unexpected results when `axis` was a tuple. For every pair
    of axes of a rank-3 input, summing the softmax output over exactly
    those axes must yield 1 everywhere.
    """
    # NOTE: renamed from `input` to avoid shadowing the builtin.
    x = np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
    # Iterate the combinations lazily; no need to materialize a list.
    for axis in combinations(range(3), 2):
        result = keras.ops.nn.softmax(x, axis=axis)
        # np.sum accepts a tuple axis, collapsing all softmax axes at once.
        normalized_sum_by_axis = np.sum(np.asarray(result), axis=axis)
        self.assertAllClose(normalized_sum_by_axis, 1.0)

def test_log_softmax(self):
x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
self.assertAllClose(
Expand Down

0 comments on commit 93786e3

Please sign in to comment.