diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 9ee94de7cf2..720dd965d0c 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -43,6 +43,10 @@ def in_top_k(targets, predictions, k): preds_at_label = jnp.take_along_axis( predictions, jnp.expand_dims(targets, axis=-1), axis=-1 ) + # `nan` shouldn't be considered as large probability. + preds_at_label = jnp.where( + jnp.isnan(preds_at_label), -jnp.inf, preds_at_label + ) rank = 1 + jnp.sum(jnp.greater(predictions, preds_at_label), axis=-1) return jnp.less_equal(rank, k) diff --git a/keras/src/backend/numpy/math.py b/keras/src/backend/numpy/math.py index 630b7912695..c75292c34d2 100644 --- a/keras/src/backend/numpy/math.py +++ b/keras/src/backend/numpy/math.py @@ -53,26 +53,17 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False): def top_k(x, k, sorted=False): - sorted_indices = np.argsort(x, axis=-1)[..., ::-1] - sorted_values = np.sort(x, axis=-1)[..., ::-1] - if sorted: # Take the k largest values. + sorted_indices = np.argsort(x, axis=-1)[..., ::-1] + sorted_values = np.take_along_axis(x, sorted_indices, axis=-1) top_k_values = sorted_values[..., :k] top_k_indices = sorted_indices[..., :k] else: # Partition the array such that all values larger than the k-th # largest value are to the right of it. - top_k_values = np.partition(x, -k, axis=-1)[..., -k:] top_k_indices = np.argpartition(x, -k, axis=-1)[..., -k:] - - # Get the indices in sorted order. - idx = np.argsort(-top_k_values, axis=-1) - - # Get the top k values and their indices. - top_k_values = np.take_along_axis(top_k_values, idx, axis=-1) - top_k_indices = np.take_along_axis(top_k_indices, idx, axis=-1) - + top_k_values = np.take_along_axis(x, top_k_indices, axis=-1) return top_k_values, top_k_indices diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py index 78caa448be3..7acbdb1b8b7 100644 --- a/keras/src/ops/math_test.py +++ b/keras/src/ops/math_test.py @@ -632,6 +632,14 @@ def test_in_top_k(self): kmath.in_top_k(targets, predictions, k=3), [True, True, True] ) + # Test `nan` in predictions + # https://github.com/keras-team/keras/issues/19995 + targets = np.array([1, 0]) + predictions = np.array([[0.1, np.nan, 0.5], [0.3, 0.2, 0.5]]) + self.assertAllEqual( + kmath.in_top_k(targets, predictions, k=2), [False, True] + ) + def test_logsumexp(self): x = np.random.rand(5, 5) outputs = kmath.logsumexp(x)