Skip to content

Commit

Permalink
Fix test_softmax_correctness_with_axis_tuple test case failed on torc…
Browse files Browse the repository at this point in the history
…h gpu ci. (keras-team#20025)
  • Loading branch information
shashaka committed Jul 22, 2024
1 parent 93786e3 commit f9501f5
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion keras/src/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,9 @@ def test_softmax_correctness_with_axis_tuple(self):
combination = combinations(range(3), 2)
for axis in list(combination):
result = keras.ops.nn.softmax(input, axis=axis)
normalized_sum_by_axis = np.sum(np.asarray(result), axis=axis)
normalized_sum_by_axis = np.sum(
ops.convert_to_numpy(result), axis=axis
)
self.assertAllClose(normalized_sum_by_axis, 1.0)

def test_log_softmax(self):
Expand Down

0 comments on commit f9501f5

Please sign in to comment.