Skip to content

Commit

Permalink
Update package and python module
Browse files Browse the repository at this point in the history
  • Loading branch information
jinlow committed Feb 24, 2024
1 parent df36dc5 commit 7e1136b
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 10 deletions.
13 changes: 13 additions & 0 deletions python/python/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

class KDTree:
    """Simple KDTree implementation backed by the native extension module."""

    # NOTE(review): the native constructor declares this parameter as
    # `records`, not `points` -- keyword callers should confirm which name the
    # compiled module actually accepts before relying on this stub.
    def __init__(
        self, points: list[tuple[str | int | float, list[float]]], min_points: int = 30
    ) -> None:
        """Build a tree from ``(label, coordinates)`` records.

        Args:
            points: Records to index; each is a label paired with its
                coordinate vector.
            min_points: A node holding fewer than this many records (or fewer
                than 3) is not split any further. Defaults to 30, matching
                the default declared by the native binding.
        """
        ...

    def get_nearest_neighbors(
        self, point: list[float], k: int = 1
    ) -> list[tuple[float, str | int | float]]:
        """Get the k nearest neighbors to ``point``.

        Args:
            point: Coordinates of the query point.
            k: Number of neighbors to return. Defaults to 1.

        Returns:
            A list of ``(distance, label)`` tuples, distance first. The
            distance is the squared Euclidean distance to ``point``.
        """
        ...
1 change: 0 additions & 1 deletion python/python/nearest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .nearest import *


__doc__ = nearest.__doc__
if hasattr(nearest, "__all__"):
__all__ = nearest.__all__
Empty file added python/python/nearest/py.typed
Empty file.
18 changes: 15 additions & 3 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,23 @@ pub struct KDTree {

#[pymethods]
impl KDTree {
/// Instantiate a new KDTree Object.
#[new]
fn new(records: Vec<(DataType, Vec<f32>)>) -> Self {
#[pyo3(signature = (records, min_points=30))]
fn new(records: Vec<(DataType, Vec<f32>)>, min_points: usize) -> Self {
KDTree {
tree: nearest_rust::KDTree::from_iter(
records.into_iter().map(|(d, p)| nearest_rust::Data::new(d, p)),
records
.into_iter()
.map(|(d, p)| nearest_rust::Data::new(d, p)),
min_points,
)
.unwrap(),
}
}

/// Get the K nearest neighbors to a point.
#[pyo3(signature = (point, k=1))]
pub fn get_nearest_neighbors(
&self,
py: Python,
Expand All @@ -41,7 +49,11 @@ impl KDTree {
let raw_point = nearest_rust::Point::new(point);
Ok(self
.tree
.get_nearest_neighbors(&raw_point, k, &nearest_rust::SquaredEuclideanDistance::default())
.get_nearest_neighbors(
&raw_point,
k,
&nearest_rust::SquaredEuclideanDistance::default(),
)
.iter()
.map(|n| match &n.data {
DataType::Str(v) => (n.distance, v.into_py(py)),
Expand Down
18 changes: 12 additions & 6 deletions src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,10 @@ fn build_tree<T: Clone>(
data_location: usize,
depth: usize,
point_len: usize,
min_points: usize,
) -> NodeOrDataPointer {
// Stop splitting once fewer than `min_points` records remain; at least 3 records are always required to split.
if data.len() < 3 {
if (data.len() < min_points) || (data.len() < 3) {
return NodeOrDataPointer::Data((data_location, (data_location + data.len())));
}
let axis = depth % point_len;
Expand All @@ -160,24 +161,29 @@ fn build_tree<T: Clone>(
data_location,
depth + 1,
point_len,
min_points,
)),
right: Box::new(build_tree(
&mut data[(median + 1)..],
data_location + median + 1,
depth + 1,
point_len,
min_points,
)),
};
return NodeOrDataPointer::Node(node);
}

impl<T: Clone> KDTree<T> {
pub fn from_iter<I: Iterator<Item = Data<T>>>(data: I) -> Result<Self, NearestError> {
Self::from_vec(data.collect())
pub fn from_iter<I: Iterator<Item = Data<T>>>(
data: I,
min_points: usize,
) -> Result<Self, NearestError> {
Self::from_vec(data.collect(), min_points)
}
pub fn from_vec(mut data: Vec<Data<T>>) -> Result<Self, NearestError> {
pub fn from_vec(mut data: Vec<Data<T>>, min_points: usize) -> Result<Self, NearestError> {
let point_len = data[0].point.shape();
let root_node = build_tree(&mut data, 0, 0, point_len);
let root_node = build_tree(&mut data, 0, 0, point_len, min_points);
Ok(KDTree {
root_node,
data,
Expand Down Expand Up @@ -312,7 +318,7 @@ mod tests {
Data::new("Tokyo", vec![35.690, 139.692]),
];
let data_len = data.len();
let tree = KDTree::from_vec(data).unwrap();
let tree = KDTree::from_vec(data, 1).unwrap();
let mut stack = vec![tree.get_root_node().unwrap()];
let mut found_data = vec![
tree.get_root_node().unwrap().data_pointer
Expand Down

0 comments on commit 7e1136b

Please sign in to comment.