diff --git a/arkouda/alignment.py b/arkouda/alignment.py index 2d33566152..546a42dfa1 100644 --- a/arkouda/alignment.py +++ b/arkouda/alignment.py @@ -126,18 +126,19 @@ def find(query, space, all_occurrences=False, remove_missing=False): occurrences as a pdarray. Defaults to only finding the first occurrence. Finding all occurrences is not yet supported on sequences of arrays remove_missing: bool + Automatically enabled when all_occurrences is True. If False, return -1 for any items in query not found in space. If True, remove these and only return indices of items that are found. Returns ------- indices : pdarray or SegArray - For each item in query, its index in space. If remove_missing is True, - exclued missing values otherwise return -1. If all_occurrences is False, + For each item in query, its index in space. If all_occurrences is False, the return will be a pdarray of the first index where each value in the - query appears in the space. if all_occurrences is True, the return will be + query appears in the space. If all_occurrences is True, the return will be a SegArray containing every index where each value in the query appears in - the space. + the space. If all_occurrences is True, remove_missing is automatically enabled. + If remove_missing is True, exclude missing values, otherwise return -1. 
Examples -------- @@ -158,30 +159,6 @@ def find(query, space, all_occurrences=False, remove_missing=False): >>> ak.find(arr1, arr2, remove_missing=True) array([0 1 2 5 8 5 11 5 0]) - # set all_occurrences to True, the first index of each list - # is the first occurence and should match the default - >>> ak.find(arr1, arr2, all_occurrences=True).to_list() - [[-1], - [-1], - [-1], - [0, 4], - [1, 3, 10], - [-1], - [-1], - [-1], - [2, 6, 12, 13], - [-1], - [5, 7], - [-1], - [8, 9, 14], - [-1], - [5, 7], - [-1], - [-1], - [11, 15], - [5, 7], - [0, 4]] - # set both remove_missing and all_occurrences to True, missing values # will be empty segments >>> ak.find(arr1, arr2, remove_missing=True, all_occurrences=True).to_list() @@ -233,7 +210,21 @@ def find(query, space, all_occurrences=False, remove_missing=False): # Group on terms g = GroupBy(c, dropna=False) # For each term, count how many times it appears in the search space - space_multiplicity = g.sum(i < spacesize)[1] + + # since we reuse (i < spacesize)[g.permutation] later, we call sum aggregation manually + less_than = (i < spacesize)[g.permutation] + repMsg = generic_msg( + cmd="segmentedReduction", + args={ + "values": less_than, + "segments": g.segments, + "op": "sum", + "skip_nan": True, + "ddof": 1, + }, + ) + + space_multiplicity = create_pdarray(repMsg) has_duplicates = (space_multiplicity > 1).any() # handle duplicate terms in space if has_duplicates: @@ -243,28 +234,13 @@ def find(query, space, all_occurrences=False, remove_missing=False): from arkouda.segarray import SegArray - # use segmented mink to select space_multiplicity number of elements - # and create a segarray which contains all the indices - # in our query space, instead of just the min for each segment - - # only calculate where to place the negatives if remove_missing is false - negative_at = "" if remove_missing else space_multiplicity == 0 - repMsg = generic_msg( - cmd="segmentedExtremaK", - args={ - "vals": i[g.permutation], - "segs": 
g.segments, - "segLens": g.size()[1], - "kArray": space_multiplicity, - "isMin": True, - "removeMissing": remove_missing, - "negativeAt": negative_at, - }, - ) - min_k_vals = create_pdarray(repMsg) + # create a segarray which contains all the indices from query + # in our search space, instead of just the min for each segment + + # I'm not completely convinced there isn't a better way to get this given the + # amount of structure, but this is no longer the bottleneck of the computation + min_k_vals = i[g.permutation][less_than] seg_idx = g.broadcast(arange(g.segments.size))[i >= spacesize] - if not remove_missing: - space_multiplicity += negative_at min_k_segs = cumsum(space_multiplicity) - space_multiplicity sa = SegArray(min_k_segs, min_k_vals) return sa[seg_idx] diff --git a/arkouda/pdarraysetops.py b/arkouda/pdarraysetops.py index 18e02e663f..8adc5eb94c 100644 --- a/arkouda/pdarraysetops.py +++ b/arkouda/pdarraysetops.py @@ -275,7 +275,6 @@ def indexof1d(query: groupable, space: groupable) -> pdarray: RuntimeError Raised if the dtype of either array is not supported """ - # from arkouda.alignment import find as akfind from arkouda.categorical import Categorical as Categorical_ if isinstance(query, (pdarray, Strings, Categorical_)): @@ -284,15 +283,16 @@ def indexof1d(query: groupable, space: groupable) -> pdarray: elif isinstance(query, pdarray) and not isinstance(space, pdarray): raise TypeError("If keys is pdarray, arr must also be pdarray") - repMsg = generic_msg( - cmd="indexof1d", - args={"keys": query, "arr": space}, - ) - return create_pdarray(cast(str, repMsg)) - - # TODO see issue #3229 reverted back to old implementation until we can investigate - # found = akfind(query, space, all_occurrences=True, remove_missing=True) - # return found if isinstance(found, pdarray) else found.values + # repMsg = generic_msg( + # cmd="indexof1d", + # args={"keys": query, "arr": space}, + # ) + # return create_pdarray(cast(str, repMsg)) + + from arkouda.alignment 
import find as akfind + + found = akfind(query, space, all_occurrences=True, remove_missing=True) + return found if isinstance(found, pdarray) else found.values # fmt: off