Skip to content

Commit

Permalink
Closes #1912: Updates to search_interval (#1913)
Browse files Browse the repository at this point in the history
This PR (closes #1912):
- Changes the default value of `hierarchical` in `search_intervals` from False to True
- Fixes a bug in how boundaries are handled in the `non_overlapping` check
- Adds a `hierarchical` parameter to `interval_lookup`, which is False by default

Co-authored-by: Pierce Hayes <pierce314159@users.noreply.github.com>
  • Loading branch information
stress-tess and Pierce Hayes authored Nov 17, 2022
1 parent 77c7bef commit cdeab05
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 27 deletions.
40 changes: 21 additions & 19 deletions arkouda/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def in1d_intervals(vals, intervals, symmetric=False):
return found


def search_intervals(vals, intervals, tiebreak=None, hierarchical=False):
def search_intervals(vals, intervals, tiebreak=None, hierarchical=True):
"""
Given an array of query vals and non-overlapping, closed intervals, return
the index of the best (see tiebreak) interval containing each query value,
Expand Down Expand Up @@ -414,27 +414,28 @@ def search_intervals(vals, intervals, tiebreak=None, hierarchical=False):
bounds_okay = True
break
needtocheck &= lo == hi
# check non_overlapping
left = high[0][:-1]
right = low[0][1:]
not_overlapping = True
if (left <= right).any():
not_overlapping = False
else:
boundary = left != right
for lo, hi in zip(low[1:], high[1:]):
left = hi[:-1]
right = lo[1:]
_ = left <= right
if not (_ | boundary).all():
not_overlapping = False
break
boundary = boundary | (left != right)
else:
bounds_okay = all((hi >= lo).all() for hi, lo in zip(high, low))

if not bounds_okay:
raise ValueError("Upper bounds must be greater than lower bounds")

left = high[0][:-1]
right = low[0][1:]
not_overlapping = True
if (left < right).any():
not_overlapping = False
else:
boundary = left != right
for lo, hi in zip(low[1:], high[1:]):
left = hi[:-1]
right = lo[1:]
if not ((left <= right) | boundary).all():
not_overlapping = False
break
boundary = boundary | (left != right)

perm = coargsort([concatenate((lo, va, hi)) for lo, va, hi in zip(low, vals, high)])

if singleton or (isinstance(vals, Sequence) and hierarchical):
Expand Down Expand Up @@ -569,13 +570,14 @@ def is_cosorted(arrays):
for array in arrays[1:]:
left = array[:-1]
right = array[1:]
if not ((left <= right) | boundary).all():
_ = left <= right
if not (_ | boundary).all():
return False
boundary = boundary | (left != right)
return True


def interval_lookup(keys, values, arguments, fillvalue=-1, tiebreak=None):
def interval_lookup(keys, values, arguments, fillvalue=-1, tiebreak=None, hierarchical=False):
"""
Apply a function defined over intervals to an array of arguments.
Expand Down Expand Up @@ -605,7 +607,7 @@ def interval_lookup(keys, values, arguments, fillvalue=-1, tiebreak=None):
if isinstance(values, Categorical):
codes = interval_lookup(keys, values.codes, arguments, fillvalue=values._NAcode)
return Categorical.from_codes(codes, values.categories, NAvalue=values.NAvalue)
idx = search_intervals(arguments, keys, tiebreak=tiebreak)
idx = search_intervals(arguments, keys, tiebreak=tiebreak, hierarchical=hierarchical)
arguments_size = arguments.size if isinstance(arguments, pdarray) else arguments[0].size
res = zeros(arguments_size, dtype=values.dtype)
if fillvalue is not None:
Expand Down
16 changes: 8 additions & 8 deletions tests/alignment_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ def test_multi_array_search_interval(self):
ends = (ak.array([4, 14, 24]), ak.array([4, 14, 24]))
vals = (ak.array([3, 13, 23]), ak.array([23, 13, 3]))
ans = [-1, 1, -1]
self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends)).to_list())
self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends), hierarchical=False).to_list())
self.assertListEqual(ans, ak.interval_lookup((starts, ends), ak.arange(3), vals).to_list())

vals = (ak.array([23, 13, 3]), ak.array([23, 13, 3]))
ans = [2, 1, 0]
self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends)).to_list())
self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends), hierarchical=False).to_list())
self.assertListEqual(ans, ak.interval_lookup((starts, ends), ak.arange(3), vals).to_list())

vals = (ak.array([23, 13, 33]), ak.array([23, 13, 3]))
ans = [2, 1, -1]
self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends)).to_list())
self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends), hierarchical=False).to_list())
self.assertListEqual(ans, ak.interval_lookup((starts, ends), ak.arange(3), vals).to_list())

# test hierarchical flag
Expand All @@ -55,11 +55,11 @@ def test_multi_array_search_interval(self):
vals = (ak.array([0, 0, 2, 5, 5, 6, 6, 9]), ak.array([0, 20, 1, 5, 15, 0, 12, 30]))

self.assertListEqual(
ak.search_intervals(vals, (starts, ends)).to_list(), [0, -1, 0, 0, 1, -1, 1, -1]
ak.search_intervals(vals, (starts, ends), hierarchical=False).to_list(),
[0, -1, 0, 0, 1, -1, 1, -1],
)
self.assertListEqual(
ak.search_intervals(vals, (starts, ends), hierarchical=True).to_list(),
[0, 0, 0, 0, 1, 1, 1, -1],
ak.search_intervals(vals, (starts, ends)).to_list(), [0, 0, 0, 0, 1, 1, 1, -1]
)

def test_search_interval_nonunique(self):
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_representative_cases(self):
tiebreak_smallest = (y1 - y0) * (x1 - x0)
first_answer = [-1, -1, 0, 0, -1, 0, 2, 0, -1, 0, 0, 3, -1]
smallest_answer = [-1, -1, 0, 2, -1, 2, 2, 1, -1, 0, 0, 3, -1]
first_result = ak.search_intervals(values, intervals)
first_result = ak.search_intervals(values, intervals, hierarchical=False)
self.assertListEqual(first_result.to_list(), first_answer)
smallest_result = ak.search_intervals(values, intervals, tiebreak=tiebreak_smallest)
smallest_result = ak.search_intervals(values, intervals, tiebreak=tiebreak_smallest, hierarchical=False)
self.assertListEqual(smallest_result.to_list(), smallest_answer)

0 comments on commit cdeab05

Please sign in to comment.