Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trying associated types #44

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions examples/src/ann_bench.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![deny(clippy::all)]
use hora::core;
use hora::core::ann_index::ANNIndex;
use hora::core::node::MemoryNode;
use std::collections::HashSet;
use std::time::SystemTime;

Expand Down Expand Up @@ -79,9 +80,9 @@ fn bench_ssg<E: core::node::FloatElement>(
let mut metrics_stats: Vec<StatMetrics> = Vec::new();
for params in params_set.iter() {
println!("start params {:?}", params);
let mut ssg_idx = Box::new(hora::index::ssg_idx::SSGIndex::<E, usize>::new(
dimension, params,
));
let mut ssg_idx = Box::new(
hora::index::ssg_idx::SSGIndex::<E, MemoryNode<E, usize>>::new(dimension, params),
);
make_idx_baseline(train, &mut ssg_idx);
metrics_stats.push(bench_calc(ssg_idx, test, neighbors));
println!("finish params {:?}", params);
Expand Down Expand Up @@ -130,9 +131,10 @@ fn bench_hnsw<E: core::node::FloatElement>(

let mut metrics_stats: Vec<StatMetrics> = Vec::new();
for params in params_set.iter() {
let mut hnsw_idx = Box::new(hora::index::hnsw_idx::HNSWIndex::<E, usize>::new(
dimension, params,
));
let mut hnsw_idx = Box::new(hora::index::hnsw_idx::HNSWIndex::<
usize,
MemoryNode<E, usize>,
>::new(dimension, params));
make_idx_baseline(train, &mut hnsw_idx);
metrics_stats.push(bench_calc(hnsw_idx, test, neighbors));
println!("finish params {:?}", params);
Expand Down Expand Up @@ -164,9 +166,11 @@ fn bench_ivfpq<E: core::node::FloatElement>(

let mut metrics_stats: Vec<StatMetrics> = Vec::new();
for params in params_set.iter() {
let mut ivfpq_idx = Box::new(hora::index::pq_idx::IVFPQIndex::<E, usize>::new(
dimension, params,
));
let mut ivfpq_idx = Box::new(hora::index::pq_idx::IVFPQIndex::<
E,
usize,
MemoryNode<E, usize>,
>::new(dimension, params));
make_idx_baseline(train, &mut ivfpq_idx);
metrics_stats.push(bench_calc(ivfpq_idx, test, neighbors));
println!("finish params {:?}", params);
Expand All @@ -184,7 +188,7 @@ fn bench_ivfpq<E: core::node::FloatElement>(
}
}

fn bench_calc<E: core::node::FloatElement, T: ANNIndex<E, usize> + ?Sized>(
fn bench_calc<E: core::node::FloatElement, T: ANNIndex<E, usize, MemoryNode<E, usize>> + ?Sized>(
ann_idx: Box<T>,
test: &Vec<Vec<E>>,
neighbors: &Vec<HashSet<usize>>,
Expand Down Expand Up @@ -223,17 +227,17 @@ fn bench_calc<E: core::node::FloatElement, T: ANNIndex<E, usize> + ?Sized>(
}
}

fn make_idx_baseline<E: core::node::FloatElement, T: ANNIndex<E, usize> + ?Sized>(
fn make_idx_baseline<
E: core::node::FloatElement,
T: ANNIndex<E, usize, MemoryNode<E, usize>> + ?Sized,
>(
embs: &Vec<Vec<E>>,
idx: &mut Box<T>,
) {
let start = SystemTime::now();
for i in 0..embs.len() {
idx.add_node(&core::node::Node::<E, usize>::new_with_idx(
embs[i].as_slice(),
i,
))
.unwrap();
idx.add_node(&core::node::Node::new_with_idx(embs[i].as_slice(), i))
.unwrap();
}
idx.build(core::metrics::Metric::DotProduct).unwrap();
let since_start = SystemTime::now()
Expand Down
3 changes: 2 additions & 1 deletion examples/src/demo.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use hora::core::ann_index::ANNIndex;
use hora::core::node::MemoryNode;
use rand::{thread_rng, Rng};
use rand_distr::{Distribution, Normal};

Expand All @@ -18,7 +19,7 @@ pub fn demo() {
}

// init index
let mut index = hora::index::hnsw_idx::HNSWIndex::<f32, usize>::new(
let mut index = hora::index::hnsw_idx::HNSWIndex::<usize, MemoryNode<f32, usize>>::new(
dimension,
&hora::index::hnsw_params::HNSWParams::<f32>::default(),
);
Expand Down
17 changes: 10 additions & 7 deletions src/core/ann_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use serde::de::DeserializeOwned;
/// ```
///

pub trait ANNIndex<E: node::FloatElement, T: node::IdxType>: Send + Sync {
pub trait ANNIndex<E: node::FloatElement, T: node::IdxType, N: node::Node<E = E, T = T>>:
Send + Sync
{
/// build up the ANN index
///
/// builds the index from all nodes that have been added so far; this takes some time, and how long depends on the algorithm
Expand All @@ -33,13 +35,13 @@ pub trait ANNIndex<E: node::FloatElement, T: node::IdxType>: Send + Sync {
///
/// it allocates space on the heap (a Vector) and initializes a `Node`
/// returns `Err(&'static str)` if something goes wrong while adding; the `&'static str` is the reason, for debugging
fn add_node(&mut self, item: &node::Node<E, T>) -> Result<(), &'static str>;
fn add_node(&mut self, item: &N) -> Result<(), &'static str>;

/// add node
///
/// calls `add_node()` internally
fn add(&mut self, vs: &[E], idx: T) -> Result<(), &'static str> {
self.add_node(&node::Node::new_with_idx(vs, idx))
self.add_node(&N::new_with_idx(vs, idx))
}

/// add multiple nodes at once
Expand All @@ -50,7 +52,7 @@ pub trait ANNIndex<E: node::FloatElement, T: node::IdxType>: Send + Sync {
return Err("vector's size is different with index");
}
for idx in 0..vss.len() {
let n = node::Node::new_with_idx(vss[idx], indices[idx].clone());
let n = N::new_with_idx(vss[idx], indices[idx].clone());
if let Err(err) = self.add_node(&n) {
return Err(err);
}
Expand All @@ -71,14 +73,14 @@ pub trait ANNIndex<E: node::FloatElement, T: node::IdxType>: Send + Sync {
}

/// internal method: search for the k nearest neighbor nodes
fn node_search_k(&self, item: &node::Node<E, T>, k: usize) -> Vec<(node::Node<E, T>, E)>;
fn node_search_k(&self, item: &N, k: usize) -> Vec<(N, E)>;

/// search for k nearest neighbors and return full info
///
/// it returns each node's full info, including the original vectors and the metric distance
///
/// it requires `item` to be a slice whose length equals the index dimension; otherwise it panics
fn search_nodes(&self, item: &[E], k: usize) -> Vec<(node::Node<E, T>, E)> {
fn search_nodes(&self, item: &[E], k: usize) -> Vec<(N, E)> {
assert_eq!(item.len(), self.dimension());
self.node_search_k(&node::Node::new(item), k)
}
Expand Down Expand Up @@ -141,7 +143,8 @@ pub trait ANNIndex<E: node::FloatElement, T: node::IdxType>: Send + Sync {
pub trait SerializableIndex<
E: node::FloatElement + DeserializeOwned,
T: node::IdxType + DeserializeOwned,
>: Send + Sync + ANNIndex<E, T>
N: node::Node<E = E, T = T>,
>: Send + Sync + ANNIndex<E, T, N>
{
/// load file with path
fn load(_path: &str) -> Result<Self, &'static str>
Expand Down
10 changes: 5 additions & 5 deletions src/core/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,10 @@ impl<E: node::FloatElement> Kmeans<E> {
}
}

pub fn general_kmeans<E: node::FloatElement, T: node::IdxType>(
pub fn general_kmeans<E: node::FloatElement, T: node::IdxType, N: node::Node<E = E, T = T>>(
k: usize,
epoch: usize,
nodes: &[Box<node::Node<E, T>>],
nodes: &[Box<N>],
mt: metrics::Metric,
) -> Vec<usize> {
if nodes.is_empty() {
Expand All @@ -239,7 +239,7 @@ pub fn general_kmeans<E: node::FloatElement, T: node::IdxType>(
let mut idx = 0;
let mut distance = E::max_value();
for (i, _item) in means.iter().enumerate() {
let _distance = node.metric(&means[i], mt).unwrap();
let _distance = node.metric(&**means[i], mt).unwrap();
if _distance < distance {
idx = i;
distance = _distance;
Expand Down Expand Up @@ -277,7 +277,7 @@ pub fn general_kmeans<E: node::FloatElement, T: node::IdxType>(
let mut mean_idx = 0;
let mut mean_distance = E::max_value();
nodes.iter().zip(0..nodes.len()).for_each(|(node, i)| {
let distance = node.metric(mean, mt).unwrap();
let distance = node.metric(&***mean, mt).unwrap();
if distance < mean_distance {
mean_idx = i;
mean_distance = distance;
Expand Down Expand Up @@ -347,7 +347,7 @@ mod tests {
.map(|x| x.iter().map(|p| *p as f32).collect())
.collect();

let nodes: Vec<Box<node::Node<f32, usize>>> = ns
let nodes: Vec<Box<node::MemoryNode<f32, usize>>> = ns
.iter()
.zip(0..ns.len())
.map(|(vs, idx)| Box::new(node::Node::new_with_idx(vs, idx)))
Expand Down
62 changes: 44 additions & 18 deletions src/core/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,45 @@ to_idx_type!(u32);
to_idx_type!(u64);
to_idx_type!(u128);

/// Abstraction over a point held by an index: a vector of elements plus an
/// optional id.
///
/// Implementors must also be `Clone + Send + Sync + Serialize + Default`.
pub trait Node: Clone + Send + Sync + Serialize + Default {
/// Element type of the stored vector; this trait only requires `PartialOrd`.
type E: PartialOrd;
/// Type of the id attached to the node; unconstrained at the trait level.
type T;
/// Creates a node from `vectors` without supplying an id.
fn new(vectors: &[Self::E]) -> Self;
/// Creates a node from `vectors` and attaches `id` to it.
fn new_with_idx(vectors: &[Self::E], id: Self::T) -> Self;
/// Calculates the distance between this node and `other` under metric `t`.
///
/// `other` may be any `Node` implementation that shares the same `E` and `T`.
/// Returns `Err(&'static str)` on failure; the failure conditions are
/// defined by the implementation.
fn metric(
&self,
other: &impl Node<E = Self::E, T = Self::T>,
t: metrics::Metric,
) -> Result<Self::E, &'static str>;
/// Returns a shared reference to the node's elements.
fn vectors(&self) -> &Vec<Self::E>;
/// Returns a mutable reference to the node's elements.
fn mut_vectors(&mut self) -> &mut Vec<Self::E>;
/// Sets the node's elements from `v`.
fn set_vectors(&mut self, v: &[Self::E]);
/// Returns the number of elements in the node.
fn len(&self) -> usize;
/// Returns whether the node holds no elements.
fn is_empty(&self) -> bool;
/// Returns a reference to the node's optional id.
fn idx(&self) -> &Option<Self::T>;
/// Sets the node's id to `id`.
fn set_idx(&mut self, id: Self::T);
/// Associated function (no `self`) that reports a `bool` for `vectors`.
///
/// NOTE(review): the name suggests an element-validity check, but no
/// implementing body is visible here — confirm what `false` means.
fn valid_elements(vectors: &[Self::E]) -> bool;
}

/// Node is the main container for a point in the space
///
/// it contains an array of `FloatElement` and an index
///
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Node<E: FloatElement, T: IdxType> {
pub struct MemoryNode<E: FloatElement, T: IdxType> {
vectors: Vec<E>,
idx: Option<T>, // data id, it can be any type;
}

impl<E: FloatElement, T: IdxType> Node<E, T> {
impl<E: FloatElement, T: IdxType> Node for MemoryNode<E, T> {
type E = E;
type T = T;
/// new without idx
///
/// creates a new point without an idx
pub fn new(vectors: &[E]) -> Node<E, T> {
Node::<E, T>::valid_elements(vectors);
Node {
fn new(vectors: &[E]) -> MemoryNode<E, T> {
MemoryNode::<E, T>::valid_elements(vectors);
MemoryNode {
vectors: vectors.to_vec(),
idx: Option::None,
}
Expand All @@ -111,43 +133,47 @@ impl<E: FloatElement, T: IdxType> Node<E, T> {
/// new with idx
///
/// creates a new point with an idx
pub fn new_with_idx(vectors: &[E], id: T) -> Node<E, T> {
let mut n = Node::new(vectors);
fn new_with_idx(vectors: &[E], id: T) -> MemoryNode<E, T> {
let mut n = MemoryNode::new(vectors);
n.set_idx(id);
n
}

/// calculate the point distance
pub fn metric(&self, other: &Node<E, T>, t: metrics::Metric) -> Result<E, &'static str> {
metrics::metric(&self.vectors, &other.vectors, t)
fn metric(
&self,
other: &impl Node<E = E, T = T>,
t: metrics::Metric,
) -> Result<E, &'static str> {
metrics::metric(&self.vectors, &*other.vectors(), t)
}

// return internal embeddings
pub fn vectors(&self) -> &Vec<E> {
fn vectors(&self) -> &Vec<E> {
&self.vectors
}

// return mut internal embeddings
pub fn mut_vectors(&mut self) -> &mut Vec<E> {
fn mut_vectors(&mut self) -> &mut Vec<E> {
&mut self.vectors
}

// set internal embeddings
pub fn set_vectors(&mut self, v: &[E]) {
fn set_vectors(&mut self, v: &[E]) {
self.vectors = v.to_vec();
}

// internal embeddings length
pub fn len(&self) -> usize {
fn len(&self) -> usize {
self.vectors.len()
}

pub fn is_empty(&self) -> bool {
fn is_empty(&self) -> bool {
self.vectors.is_empty()
}

// return node's idx
pub fn idx(&self) -> &Option<T> {
fn idx(&self) -> &Option<T> {
&self.idx
}

Expand All @@ -166,7 +192,7 @@ impl<E: FloatElement, T: IdxType> Node<E, T> {
}
}

impl<E: FloatElement, T: IdxType> core::fmt::Display for Node<E, T> {
impl<E: FloatElement, T: IdxType> core::fmt::Display for MemoryNode<E, T> {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "(key: {:#?}, vectors: {:#?})", self.idx, self.vectors)
}
Expand All @@ -180,7 +206,7 @@ fn node_test() {
// f64
let v = vec![1.0, 1.0];
let v2 = vec![2.0, 2.0];
let n = Node::<f64, usize>::new(&v);
let n2 = Node::<f64, usize>::new(&v2);
let n: MemoryNode<f64, usize> = Node::new(&v);
let n2: MemoryNode<f64, usize> = Node::new(&v2);
assert_eq!(n.metric(&n2, metrics::Metric::Manhattan).unwrap(), 2.0);
}
Loading