Skip to content

Commit

Permalink
Fix the building reference index when ref_index is None in DISE (#…
Browse files Browse the repository at this point in the history
…264)

* Fix the return type when calculating medoid in `get_initial_selection`

* Add test case for ref_index is None

* Remove redundant checking of ref_index
  • Loading branch information
FanwangM authored Sep 9, 2024
1 parent 198e004 commit 19c8cf7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
4 changes: 1 addition & 3 deletions selector/methods/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,6 @@ def __init__(
"""
self.r0 = r0
self.r = r0
if ref_index is not None and ref_index < 0:
raise ValueError(f"ref_index must be a non-negative integer, got {ref_index}.")
self.ref_index = ref_index
self.tol = tol
self.n_iter = n_iter
Expand Down Expand Up @@ -677,7 +675,7 @@ def get_initial_selection(x=None, x_dist=None, ref_index=None, fun_dist=None):
if x_dist is None:
x_dist = fun_dist(x)
# calculate the medoid center
initial_selections = [np.argmin(np.sum(x_dist, axis=0))]
initial_selections = [int(np.argmin(np.sum(x_dist, axis=0)))]

# the length of the distance matrix is the number of samples
if x_dist is not None:
Expand Down
10 changes: 10 additions & 0 deletions selector/methods/tests/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,16 @@ def test_directed_sphere_same_number_of_pts():
assert_equal(collector.r, 1)


def test_directed_sphere_same_number_of_pts_None():
"""Test DirectSphereExclusion with `size` = number of points in dataset with the ref_index None."""
# None as the reference point
x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]])
collector = DISE(r0=1, tol=0, ref_index=None)
selected = collector.select(x, size=3)
assert_equal(selected, [2, 0, 4])
assert_equal(collector.r, 1)


def test_directed_sphere_exclusion_select_more_number_of_pts():
"""Test DirectSphereExclusion on points on the line with `size` < number of points in dataset."""
# (0,0) as the reference point
Expand Down

0 comments on commit 19c8cf7

Please sign in to comment.