From 19c8cf7416b0293fac87fbe6aafd907d1d44d2e8 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Mon, 9 Sep 2024 11:55:13 -0400 Subject: [PATCH] Fix the building reference index when `ref_index` is None in `DISE` (#264) * Fix the return type when calculating medoid in `get_initial_selection` * Add test case for ref_index is None * Remove redundant checking of ref_index --- selector/methods/distance.py | 4 +--- selector/methods/tests/test_distance.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/selector/methods/distance.py b/selector/methods/distance.py index 4c25403c..d3f9b6ee 100644 --- a/selector/methods/distance.py +++ b/selector/methods/distance.py @@ -507,8 +507,6 @@ def __init__( """ self.r0 = r0 self.r = r0 - if ref_index is not None and ref_index < 0: - raise ValueError(f"ref_index must be a non-negative integer, got {ref_index}.") self.ref_index = ref_index self.tol = tol self.n_iter = n_iter @@ -677,7 +675,7 @@ def get_initial_selection(x=None, x_dist=None, ref_index=None, fun_dist=None): if x_dist is None: x_dist = fun_dist(x) # calculate the medoid center - initial_selections = [np.argmin(np.sum(x_dist, axis=0))] + initial_selections = [int(np.argmin(np.sum(x_dist, axis=0)))] # the length of the distance matrix is the number of samples if x_dist is not None: diff --git a/selector/methods/tests/test_distance.py b/selector/methods/tests/test_distance.py index 022c7e90..a65620da 100644 --- a/selector/methods/tests/test_distance.py +++ b/selector/methods/tests/test_distance.py @@ -301,6 +301,16 @@ def test_directed_sphere_same_number_of_pts(): assert_equal(collector.r, 1) +def test_directed_sphere_same_number_of_pts_None(): + """Test DirectSphereExclusion with `size` = number of points in dataset with the ref_index None.""" + # None as the reference point + x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]]) + collector = DISE(r0=1, tol=0, ref_index=None) + selected = collector.select(x, size=3) + assert_equal(selected, [2, 0, 4]) + assert_equal(collector.r, 1) + + def test_directed_sphere_exclusion_select_more_number_of_pts(): """Test DirectSphereExclusion on points on the line with `size` < number of points in dataset.""" # (0,0) as the reference point