Skip to content

Commit

Permalink
Fix jax's in_top_k and numpy's top_k (keras-team#20033)
Browse files Browse the repository at this point in the history
* Fix jax's

* Fix numpy's
  • Loading branch information
james77777778 committed Jul 23, 2024
1 parent 7b516d0 commit 9ebc65a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
4 changes: 4 additions & 0 deletions keras/src/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def in_top_k(targets, predictions, k):
preds_at_label = jnp.take_along_axis(
predictions, jnp.expand_dims(targets, axis=-1), axis=-1
)
# `nan` shouldn't be considered as large probability.
preds_at_label = jnp.where(
jnp.isnan(preds_at_label), -jnp.inf, preds_at_label
)
rank = 1 + jnp.sum(jnp.greater(predictions, preds_at_label), axis=-1)
return jnp.less_equal(rank, k)

Expand Down
15 changes: 3 additions & 12 deletions keras/src/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,17 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False):


def top_k(x, k, sorted=False):
sorted_indices = np.argsort(x, axis=-1)[..., ::-1]
sorted_values = np.sort(x, axis=-1)[..., ::-1]

if sorted:
# Take the k largest values.
sorted_indices = np.argsort(x, axis=-1)[..., ::-1]
sorted_values = np.take_along_axis(x, sorted_indices, axis=-1)
top_k_values = sorted_values[..., :k]
top_k_indices = sorted_indices[..., :k]
else:
# Partition the array such that all values larger than the k-th
# largest value are to the right of it.
top_k_values = np.partition(x, -k, axis=-1)[..., -k:]
top_k_indices = np.argpartition(x, -k, axis=-1)[..., -k:]

# Get the indices in sorted order.
idx = np.argsort(-top_k_values, axis=-1)

# Get the top k values and their indices.
top_k_values = np.take_along_axis(top_k_values, idx, axis=-1)
top_k_indices = np.take_along_axis(top_k_indices, idx, axis=-1)

top_k_values = np.take_along_axis(x, top_k_indices, axis=-1)
return top_k_values, top_k_indices


Expand Down
8 changes: 8 additions & 0 deletions keras/src/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,14 @@ def test_in_top_k(self):
kmath.in_top_k(targets, predictions, k=3), [True, True, True]
)

# Test `nan` in predictions
# https://github.com/keras-team/keras/issues/19995
targets = np.array([1, 0])
predictions = np.array([[0.1, np.nan, 0.5], [0.3, 0.2, 0.5]])
self.assertAllEqual(
kmath.in_top_k(targets, predictions, k=2), [False, True]
)

def test_logsumexp(self):
x = np.random.rand(5, 5)
outputs = kmath.logsumexp(x)
Expand Down

0 comments on commit 9ebc65a

Please sign in to comment.