Skip to content

Commit

Permalink
Part of Bears-R-Us#3229: Re-enable find implementation of indexof1d (
Browse files Browse the repository at this point in the history
…Bears-R-Us#3316)

This PR is part of Bears-R-Us#3229. Re-enable the `find` implementation. I modified it to have a more optimized way of finding all occurences (which will hopefully also solve the issue we were seeing in the CI). While I'm not sure if this is the best way of calculating this, it's no longer the bottleneck and it gets better performance than the current implementation.

Previously I added seeds to the `indexof1d` test, so it will be easier to reproduce any failures. I'm not closing the issue yet since there's still some code cleanup to be done, and I want to leave exisiting code to make it easier to fall back to a similar approach in the event that the `find` way still has issues.

Co-authored-by: Tess Hayes <stress-tess@users.noreply.github.com>
  • Loading branch information
stress-tess and stress-tess committed Jun 12, 2024
1 parent aed98ea commit 3043f38
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 60 deletions.
76 changes: 26 additions & 50 deletions arkouda/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,19 @@ def find(query, space, all_occurrences=False, remove_missing=False):
occurrences as a pdarray. Defaults to only finding the first occurrence.
Finding all occurrences is not yet supported on sequences of arrays
remove_missing: bool
If all_occurrences is True, remove_missing is automatically enabled.
If False, return -1 for any items in query not found in space. If True,
remove these and only return indices of items that are found.
Returns
-------
indices : pdarray or SegArray
For each item in query, its index in space. If remove_missing is True,
exclued missing values otherwise return -1. If all_occurrences is False,
For each item in query, its index in space. If all_occurrences is False,
the return will be a pdarray of the first index where each value in the
query appears in the space. if all_occurrences is True, the return will be
query appears in the space. If all_occurrences is True, the return will be
a SegArray containing every index where each value in the query appears in
the space.
the space. If all_occurrences is True, remove_missing is automatically enabled.
If remove_missing is True, exclude missing values, otherwise return -1.
Examples
--------
Expand All @@ -158,30 +159,6 @@ def find(query, space, all_occurrences=False, remove_missing=False):
>>> ak.find(arr1, arr2, remove_missing=True)
array([0 1 2 5 8 5 11 5 0])
# set all_occurrences to True, the first index of each list
# is the first occurence and should match the default
>>> ak.find(arr1, arr2, all_occurrences=True).to_list()
[[-1],
[-1],
[-1],
[0, 4],
[1, 3, 10],
[-1],
[-1],
[-1],
[2, 6, 12, 13],
[-1],
[5, 7],
[-1],
[8, 9, 14],
[-1],
[5, 7],
[-1],
[-1],
[11, 15],
[5, 7],
[0, 4]]
# set both remove_missing and all_occurrences to True, missing values
# will be empty segments
>>> ak.find(arr1, arr2, remove_missing=True, all_occurrences=True).to_list()
Expand Down Expand Up @@ -233,7 +210,21 @@ def find(query, space, all_occurrences=False, remove_missing=False):
# Group on terms
g = GroupBy(c, dropna=False)
# For each term, count how many times it appears in the search space
space_multiplicity = g.sum(i < spacesize)[1]

# since we reuse (i < spacesize)[g.permutation] later, we call sum aggregation manually
less_than = (i < spacesize)[g.permutation]
repMsg = generic_msg(
cmd="segmentedReduction",
args={
"values": less_than,
"segments": g.segments,
"op": "sum",
"skip_nan": True,
"ddof": 1,
},
)

space_multiplicity = create_pdarray(repMsg)
has_duplicates = (space_multiplicity > 1).any()
# handle duplicate terms in space
if has_duplicates:
Expand All @@ -243,28 +234,13 @@ def find(query, space, all_occurrences=False, remove_missing=False):

from arkouda.segarray import SegArray

# use segmented mink to select space_multiplicity number of elements
# and create a segarray which contains all the indices
# in our query space, instead of just the min for each segment

# only calculate where to place the negatives if remove_missing is false
negative_at = "" if remove_missing else space_multiplicity == 0
repMsg = generic_msg(
cmd="segmentedExtremaK",
args={
"vals": i[g.permutation],
"segs": g.segments,
"segLens": g.size()[1],
"kArray": space_multiplicity,
"isMin": True,
"removeMissing": remove_missing,
"negativeAt": negative_at,
},
)
min_k_vals = create_pdarray(repMsg)
# create a segarray which contains all the indices from query
# in our search space, instead of just the min for each segment

# im not completely convinced there's not a better way to get this given the
# amount of structure but this is not the bottleneck of the computation anymore
min_k_vals = i[g.permutation][less_than]
seg_idx = g.broadcast(arange(g.segments.size))[i >= spacesize]
if not remove_missing:
space_multiplicity += negative_at
min_k_segs = cumsum(space_multiplicity) - space_multiplicity
sa = SegArray(min_k_segs, min_k_vals)
return sa[seg_idx]
Expand Down
20 changes: 10 additions & 10 deletions arkouda/pdarraysetops.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ def indexof1d(query: groupable, space: groupable) -> pdarray:
RuntimeError
Raised if the dtype of either array is not supported
"""
# from arkouda.alignment import find as akfind
from arkouda.categorical import Categorical as Categorical_

if isinstance(query, (pdarray, Strings, Categorical_)):
Expand All @@ -284,15 +283,16 @@ def indexof1d(query: groupable, space: groupable) -> pdarray:
elif isinstance(query, pdarray) and not isinstance(space, pdarray):
raise TypeError("If keys is pdarray, arr must also be pdarray")

repMsg = generic_msg(
cmd="indexof1d",
args={"keys": query, "arr": space},
)
return create_pdarray(cast(str, repMsg))

# TODO see issue #3229 reverted back to old implementation until we can investigate
# found = akfind(query, space, all_occurrences=True, remove_missing=True)
# return found if isinstance(found, pdarray) else found.values
# repMsg = generic_msg(
# cmd="indexof1d",
# args={"keys": query, "arr": space},
# )
# return create_pdarray(cast(str, repMsg))

from arkouda.alignment import find as akfind

found = akfind(query, space, all_occurrences=True, remove_missing=True)
return found if isinstance(found, pdarray) else found.values


# fmt: off
Expand Down

0 comments on commit 3043f38

Please sign in to comment.