Skip to content

Commit

Permalink
Merge pull request #118 from alnaba1/optisim
Browse files Browse the repository at this point in the history
Improve Optisim Performance
  • Loading branch information
FanwangM committed Aug 31, 2022
2 parents ed6516d + b8e914c commit 79a46f3
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 41 deletions.
61 changes: 57 additions & 4 deletions DiverseSelector/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,16 @@ def select_from_cluster(self, arr, num_selected, cluster_ids=None):


class KDTreeBase(SelectionBase, ABC):
"""Base class for KDTree based subset selection."""
"""Base class for KDTree based subset selection.
Adapted from https://johnlekberg.com/blog/2020-04-17-kd-tree.html
"""

def __int__(self):
"""Initializing class."""
self.func_distance = lambda x, y: sum((i - j) ** 2 for i, j in zip(x, y))
self.BT = collections.namedtuple("BT", ["value", "index", "left", "right"])
self.NNRecord = collections.namedtuple("NNRecord", ["point", "distance"])

def _kdtree(self, arr):
"""Construct a k-d tree from an iterable of points.
Expand Down Expand Up @@ -169,7 +173,7 @@ def build(points, depth, old_indices=None):
kdtree = build(points=arr, depth=0)
return kdtree

def _find_nearest_neighbor(self, kdtree, point, threshold):
def _find_nearest_neighbor(self, kdtree, point, threshold, sort=True):
"""
Find the nearest neighbors in a k-d tree for a point.
Expand All @@ -181,6 +185,8 @@ def _find_nearest_neighbor(self, kdtree, point, threshold):
Query point for search.
threshold: float
The boundary used to mark all the points whose distance is within the threshold.
sort: boolean
Whether the results should be sorted based on lowest distance or not.
Returns
-------
Expand Down Expand Up @@ -213,11 +219,58 @@ def search(tree, depth):
search(tree=away, depth=depth + 1)

search(tree=kdtree, depth=0)
to_eliminate.sort()
to_eliminate.pop(0)
to_eliminate = [index for dist, index in to_eliminate]
if sort:
to_eliminate.sort()
return to_eliminate

def _nearest_neighbor(self, kdtree, point):
"""
Find the nearest neighbors in a k-d tree for a point.
Parameters
----------
kdtree: collections.namedtuple
KDTree organizing coordinates.
point: list
Query point for search.
threshold: float
The boundary used to mark all the points whose distance is within the threshold.
Returns
-------
to_eliminate: list
A list containing all the indices of points too close to the newly selected point.
"""
k = len(point)
best = None

def search(tree, depth):
# Recursively search through the k-d tree to find the
# nearest neighbor.
nonlocal best

if tree is None:
return

distance = self.func_distance(tree.value, point)
if best is None or distance < best.distance:
best = self.NNRecord(point=tree.value, distance=distance)

axis = depth % k
diff = point[axis] - tree.value[axis]
if diff <= 0:
close, away = tree.left, tree.right
else:
close, away = tree.right, tree.left

search(tree=close, depth=depth + 1)
if diff < best.distance:
search(tree=away, depth=depth + 1)

search(tree=kdtree, depth=0)
return best

def _eliminate(self, tree, point, threshold, num_eliminate, bv):
"""Eliminates points from being selected in future rounds.
Expand Down
70 changes: 35 additions & 35 deletions DiverseSelector/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,15 @@ def select_from_cluster(self, arr, num_selected, cluster_ids=None):
return selected


class OptiSim(SelectionBase):
class OptiSim(KDTreeBase):
"""Selecting compounds using OptiSim algorithm.
Initial point is chosen as medoid center. Points are randomly chosen and added to a subsample
if outside of radius r from all previously selected points, and discarded otherwise. Once k
number of points are added to the subsample, the point with the greatest minimum distance to
previously selected points is selected and the subsample is cleared and the process repeats.
Addapted from https://doi.org/10.1021/ci970282v
Adapted from https://doi.org/10.1021/ci970282v
"""

def __init__(
Expand Down Expand Up @@ -216,6 +216,8 @@ def __init__(
self.func_distance = func_distance
self.start_id = start_id
self.random_seed = random_seed
self.BT = collections.namedtuple("BT", ["value", "index", "left", "right"])
self.NNRecord = collections.namedtuple("NNRecord", ["point", "distance"])

def algorithm(self, arr) -> list:
"""
Expand All @@ -232,37 +234,35 @@ def algorithm(self, arr) -> list:
List of ids of selected molecules
"""
selected = [self.start_id]
recycling = []

candidates = np.delete(np.arange(0, len(arr)), selected + recycling)
subsample = {}
while len(candidates) > 0:
while len(subsample) < self.k:
if len(candidates) == 0:
if len(subsample) > 0:
break
return selected
rng = np.random.default_rng(seed=self.random_seed)
random_int = rng.integers(low=0, high=len(candidates), size=1)[0]
index_new = candidates[random_int]
distances = []
for selected_idx in selected:
data_point = arr[index_new]
selected_point = arr[selected_idx]
distance = self.func_distance(selected_point, data_point)
distances.append(distance)
min_dist = min(distances)
if min_dist > self.r:
subsample[index_new] = min_dist
else:
recycling.append(index_new)
candidates = np.delete(
np.arange(0, len(arr)),
selected + recycling + list(subsample.keys()),
)
selected.append(max(zip(subsample.values(), subsample.keys()))[1])
candidates = np.delete(np.arange(0, len(arr)), selected + recycling)
subsample = {}
tree = self._kdtree(arr)
rng = np.random.default_rng(seed=self.random_seed)
len_arr = len(arr)
bv = np.zeros(len_arr)
candidates = list(range(len_arr))
elim = self._find_nearest_neighbor(kdtree=tree, point=arr[self.start_id], threshold=self.r,
sort=False)
for idx in elim:
bv[idx] = 1
candidates = np.ma.array(candidates, mask=bv)
while len(candidates.compressed()) > 0:
try:
sublist = rng.choice(candidates.compressed(), size=self.k, replace=False)
except ValueError:
sublist = candidates.compressed()
newtree = self._kdtree(arr[selected])
best_dist = None
best_idx = None
for idx in sublist:
search = self._nearest_neighbor(newtree, arr[idx])
if best_dist is None or search.distance > best_dist:
best_dist = search.distance
best_idx = idx
selected.append(best_idx)
elim = self._find_nearest_neighbor(kdtree=tree, point=arr[best_idx], threshold=self.r,
sort=False)
for idx in elim:
bv[idx] = 1
candidates = np.ma.array(candidates, mask=bv)

return selected

Expand Down Expand Up @@ -744,7 +744,7 @@ def predict_radius(obj: Union[DirectedSphereExclusion, OptiSim], arr, num_select
bounds = [low, high]
count = 0
error = num_selected * obj.tolerance / 100
while (len(result) < num_selected - error or len(result) > num_selected + error) and count < 20:
while (len(result) < num_selected - error or len(result) > num_selected + error) and count < 10:
if bounds[1] is None:
rg = bounds[0] * 2
else:
Expand All @@ -756,7 +756,7 @@ def predict_radius(obj: Union[DirectedSphereExclusion, OptiSim], arr, num_select
else:
bounds[1] = rg
count += 1
if count == 10:
if count >= 10:
print(f"Optimal radius finder failed to converge, selected {len(result)} molecules instead "
f"of requested {num_selected}.")
obj.r = original_r
Expand Down
11 changes: 9 additions & 2 deletions DiverseSelector/test/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,19 @@ def test_optisim():
selector = OptiSim()
selected_ids = selector.select(arr=coords_cluster, num_selected=12, labels=class_labels_cluster)
# make sure all the selected indices are the same with expectation
assert_equal(selected_ids, [2, 85, 86, 59, 1, 50, 93, 68, 0, 11, 33, 46])
assert_equal(selected_ids, [2, 85, 86, 59, 1, 66, 50, 68, 0, 64, 83, 72])

selector = OptiSim()
selected_ids = selector.select(arr=coords, num_selected=12)
# make sure all the selected indices are the same with expectation
assert_equal(selected_ids, [0, 13, 21, 9, 8, 18, 39, 57, 65, 25, 6, 45])
assert_equal(selected_ids, [0, 8, 55, 37, 41, 13, 12, 42, 6, 30, 57, 76])

# tester to check if optisim gives same results as maxmin for k=>infinity
selector = OptiSim(start_id=85, k=999999)
selected_ids_optisim = selector.select(arr=coords, num_selected=12)
selector = MaxMin()
selected_ids_maxmin = selector.select(arr=arr_dist, num_selected=12)
assert_equal(selected_ids_optisim, selected_ids_maxmin)


def test_directedsphereexclusion():
Expand Down

0 comments on commit 79a46f3

Please sign in to comment.