From b0ce240fa74132ed8b5edf9d8c90792cfff1fa93 Mon Sep 17 00:00:00 2001 From: Gavin Mendel-Gleason Date: Sat, 27 May 2023 12:50:47 +0200 Subject: [PATCH 1/4] WIP: Making node generic --- src/core/ann_index.rs | 17 ++++---- src/core/kmeans.rs | 4 +- src/core/node.rs | 46 +++++++++++++-------- src/index/bruteforce_idx.rs | 25 +++++++----- src/index/hnsw_idx.rs | 45 +++++++++++---------- src/index/pq_idx.rs | 80 ++++++++++++++++++------------------- src/index/ssg_idx.rs | 27 ++++++++----- src/lib.rs | 11 +++-- 8 files changed, 142 insertions(+), 113 deletions(-) diff --git a/src/core/ann_index.rs b/src/core/ann_index.rs index 8439aba..8dec063 100644 --- a/src/core/ann_index.rs +++ b/src/core/ann_index.rs @@ -22,7 +22,9 @@ use serde::de::DeserializeOwned; /// ``` /// -pub trait ANNIndex: Send + Sync { +pub trait ANNIndex>: + Send + Sync +{ /// build up the ANN index /// /// build up index with all node which have add into before, it will cost some time, and the time it cost depends on the algorithm @@ -33,13 +35,13 @@ pub trait ANNIndex: Send + Sync { /// /// it will allocate a space in the heap(Vector), and init a `Node` /// return `Err(&'static str)` if there is something wrong with the adding process, and the `static str` is the debug reason - fn add_node(&mut self, item: &node::Node) -> Result<(), &'static str>; + fn add_node(&mut self, item: &impl node::Node) -> Result<(), &'static str>; /// add node /// /// call `add_node()` internal fn add(&mut self, vs: &[E], idx: T) -> Result<(), &'static str> { - self.add_node(&node::Node::new_with_idx(vs, idx)) + self.add_node(&N::new_with_idx(vs, idx)) } /// add multiple node one time @@ -50,7 +52,7 @@ pub trait ANNIndex: Send + Sync { return Err("vector's size is different with index"); } for idx in 0..vss.len() { - let n = node::Node::new_with_idx(vss[idx], indices[idx].clone()); + let n = N::new_with_idx(vss[idx], indices[idx].clone()); if let Err(err) = self.add_node(&n) { return Err(err); } @@ -71,14 +73,14 @@ pub trait ANNIndex: Send + Sync { } /// search for k nearest neighbors node internal method - fn node_search_k(&self, item: &node::Node, k: usize) -> Vec<(node::Node, E)>; + fn node_search_k(&self, item: &N, k: usize) -> Vec<(N, E)>; /// search for k nearest neighbors and return full info /// /// it will return the all node's info including the original vectors, and the metric distance /// /// it require the item is the slice with the same dimension with index dimension, otherwise it will panic - fn search_nodes(&self, item: &[E], k: usize) -> Vec<(node::Node, E)> { + fn search_nodes(&self, item: &[E], k: usize) -> Vec<(N, E)> { assert_eq!(item.len(), self.dimension()); self.node_search_k(&node::Node::new(item), k) } @@ -141,7 +143,8 @@ pub trait ANNIndex: Send + Sync { pub trait SerializableIndex< E: node::FloatElement + DeserializeOwned, T: node::IdxType + DeserializeOwned, ->: Send + Sync + ANNIndex + N: node::Node, +>: Send + Sync + ANNIndex { /// load file with path fn load(_path: &str) -> Result diff --git a/src/core/kmeans.rs b/src/core/kmeans.rs index 3e887dc..4b170b0 100644 --- a/src/core/kmeans.rs +++ b/src/core/kmeans.rs @@ -213,10 +213,10 @@ impl Kmeans { } } -pub fn general_kmeans( +pub fn general_kmeans>( k: usize, epoch: usize, - nodes: &[Box>], + nodes: &[Box], mt: metrics::Metric, ) -> Vec { if nodes.is_empty() { diff --git a/src/core/node.rs b/src/core/node.rs index 68ff518..e2fa8f7 100644 --- a/src/core/node.rs +++ b/src/core/node.rs @@ -86,23 +86,37 @@ to_idx_type!(u32); to_idx_type!(u64); to_idx_type!(u128); +pub trait Node: Send + Sync { + fn new(vectors: &[E]) -> Self; + fn new_with_index(vectors: &[E], id: T) -> Self; + fn metric(&self, other: impl Node, t: metrics::Metric) -> Result; + fn vectors(&self) -> Vec; + fn mut_vectors(&mut self) -> &mut Vec; + fn set_vectors(&mut self, v: &[E]); + fn len(&self) -> usize; + fn is_empty(&self) -> bool; + fn idx(&self) -> &Option; + fn set_idx(&mut self, id: T); + fn valid_elements(vectors: &[E]) -> bool; +} + /// Node is the main container for the point in the space /// /// it contains a array of `FloatElement` and a index /// #[derive(Clone, Debug, Default, Serialize, Deserialize)] -pub struct Node { +pub struct MemoryNode { vectors: Vec, idx: Option, // data id, it can be any type; } -impl Node { +impl Node for MemoryNode { /// new without idx /// /// new a point without a idx - pub fn new(vectors: &[E]) -> Node { - Node::::valid_elements(vectors); - Node { + fn new(vectors: &[E]) -> MemoryNode { + MemoryNode::::valid_elements(vectors); + MemoryNode { vectors: vectors.to_vec(), idx: Option::None, } @@ -111,43 +125,43 @@ impl Node { /// new with idx /// /// new a point with a idx - pub fn new_with_idx(vectors: &[E], id: T) -> Node { - let mut n = Node::new(vectors); + fn new_with_idx(vectors: &[E], id: T) -> MemoryNode { + let mut n = MemoryNode::new(vectors); n.set_idx(id); n } /// calculate the point distance - pub fn metric(&self, other: &Node, t: metrics::Metric) -> Result { - metrics::metric(&self.vectors, &other.vectors, t) + fn metric(&self, other: &impl Node, t: metrics::Metric) -> Result { + metrics::metric(&self.vectors, &other.vectors(), t) } // return internal embeddings - pub fn vectors(&self) -> &Vec { + fn vectors(&self) -> &Vec { &self.vectors } // return mut internal embeddings - pub fn mut_vectors(&mut self) -> &mut Vec { + fn mut_vectors(&mut self) -> &mut Vec { &mut self.vectors } // set internal embeddings - pub fn set_vectors(&mut self, v: &[E]) { + fn set_vectors(&mut self, v: &[E]) { self.vectors = v.to_vec(); } // internal embeddings length - pub fn len(&self) -> usize { + fn len(&self) -> usize { self.vectors.len() } - pub fn is_empty(&self) -> bool { + fn is_empty(&self) -> bool { self.vectors.is_empty() } // return node's idx - pub fn idx(&self) -> &Option { + fn idx(&self) -> &Option { &self.idx } @@ -166,7 +180,7 @@ impl Node { } } -impl core::fmt::Display for Node { +impl core::fmt::Display for MemoryNode { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { write!(f, "(key: {:#?}, vectors: {:#?})", self.idx, self.vectors) } diff --git a/src/index/bruteforce_idx.rs b/src/index/bruteforce_idx.rs index 795460d..3d563f7 100644 --- a/src/index/bruteforce_idx.rs +++ b/src/index/bruteforce_idx.rs @@ -13,16 +13,16 @@ use std::fs::File; use std::io::Write; #[derive(Debug, Serialize, Deserialize)] -pub struct BruteForceIndex { +pub struct BruteForceIndex> { #[serde(skip_serializing, skip_deserializing)] - nodes: Vec>>, - tmp_nodes: Vec>, // only use for serialization scene + nodes: Vec>, + tmp_nodes: Vec, // only use for serialization scene mt: metrics::Metric, dimension: usize, } -impl BruteForceIndex { - pub fn new(dimension: usize, _params: &BruteForceParams) -> BruteForceIndex { +impl> BruteForceIndex { + pub fn new(dimension: usize, _params: &BruteForceParams) -> BruteForceIndex { BruteForceIndex:: { nodes: Vec::new(), mt: metrics::Metric::Unknown, @@ -32,19 +32,21 @@ impl BruteForceIndex { } } -impl ann_index::ANNIndex for BruteForceIndex { +impl> ann_index::ANNIndex + for BruteForceIndex +{ fn build(&mut self, mt: metrics::Metric) -> Result<(), &'static str> { self.mt = mt; Result::Ok(()) } - fn add_node(&mut self, item: &node::Node) -> Result<(), &'static str> { + fn add_node(&mut self, item: &N) -> Result<(), &'static str> { self.nodes.push(Box::new(item.clone())); Result::Ok(()) } fn built(&self) -> bool { true } - fn node_search_k(&self, item: &node::Node, k: usize) -> Vec<(node::Node, E)> { + fn node_search_k(&self, item: &N, k: usize) -> Vec<(N, E)> { let mut heap = BinaryHeap::with_capacity(k + 1); self.nodes .iter() @@ -81,8 +83,11 @@ impl ann_index::ANNIndex for Brut } } -impl - ann_index::SerializableIndex for BruteForceIndex +impl< + E: node::FloatElement + DeserializeOwned, + T: node::IdxType + DeserializeOwned, + N: node::Node + DeserializeOwned, + > ann_index::SerializableIndex for BruteForceIndex { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); diff --git a/src/index/hnsw_idx.rs b/src/index/hnsw_idx.rs index 9649658..04f6d4b 100644 --- a/src/index/hnsw_idx.rs +++ b/src/index/hnsw_idx.rs @@ -21,7 +21,7 @@ use std::io::Write; use std::sync::RwLock; #[derive(Default, Debug, Serialize, Deserialize)] -pub struct HNSWIndex { +pub struct HNSWIndex> { _dimension: usize, // dimension _n_items: usize, // next item count _n_constructed_items: usize, @@ -35,7 +35,7 @@ pub struct HNSWIndex { #[serde(skip_serializing, skip_deserializing)] _id2neighbor0: Vec>>, //neigh_id at level 0 #[serde(skip_serializing, skip_deserializing)] - _nodes: Vec>>, // data saver + _nodes: Vec>, // data saver #[serde(skip_serializing, skip_deserializing)] _item2id: HashMap, //item_id to id in Hnsw _root_id: usize, //root of hnsw @@ -50,13 +50,13 @@ pub struct HNSWIndex { // use for serde _id2neighbor_tmp: Vec>>, _id2neighbor0_tmp: Vec>, - _nodes_tmp: Vec>, + _nodes_tmp: Vec, _item2id_tmp: Vec<(T, usize)>, _delete_ids_tmp: Vec, } -impl HNSWIndex { - pub fn new(dimension: usize, params: &HNSWParams) -> HNSWIndex { +impl> HNSWIndex { + pub fn new(dimension: usize, params: &HNSWParams) -> HNSWIndex { HNSWIndex { _dimension: dimension, _n_items: 0, @@ -236,11 +236,11 @@ impl HNSWIndex { self._has_removed && self._delete_ids.contains(&id) } - fn get_data(&self, id: usize) -> &node::Node { + fn get_data(&self, id: usize) -> &N { &self._nodes[id] } - fn get_distance_from_vec(&self, x: &node::Node, y: &node::Node) -> E { + fn get_distance_from_vec(&self, x: &N, y: &N) -> E { return metrics::metric(x.vectors(), y.vectors(), self.mt).unwrap(); } @@ -255,7 +255,7 @@ impl HNSWIndex { fn search_layer_with_candidate( &self, - search_data: &node::Node, + search_data: &N, sorted_candidates: &[Neighbor], visited_id: &mut FixedBitSet, level: usize, @@ -320,7 +320,7 @@ impl HNSWIndex { fn search_layer( &self, root: usize, - search_data: &node::Node, + search_data: &N, level: usize, ef: usize, has_deletion: bool, @@ -380,7 +380,7 @@ impl HNSWIndex { // fn search_layer_default( // &self, // root: usize, - // search_data: &node::Node, + // search_data: &N, // level: usize, // ) -> BinaryHeap> { // return self.search_layer(root, search_data, level, self._ef_build, false); @@ -388,7 +388,7 @@ impl HNSWIndex { fn search_knn( &self, - search_data: &node::Node, + search_data: &N, k: usize, ) -> Result>, &'static str> { let mut top_candidate: BinaryHeap> = BinaryHeap::new(); @@ -438,7 +438,7 @@ impl HNSWIndex { Ok(top_candidate) } - fn init_item(&mut self, data: &node::Node) -> usize { + fn init_item(&mut self, data: &N) -> usize { let cur_id = self._n_items; let mut cur_level = self.get_random_level(); if cur_id == 0 { @@ -475,7 +475,7 @@ impl HNSWIndex { Ok(()) } - fn add_item_not_constructed(&mut self, data: &node::Node) -> Result<(), &'static str> { + fn add_item_not_constructed(&mut self, data: &N) -> Result<(), &'static str> { if data.len() != self._dimension { return Err("dimension is different"); } @@ -495,7 +495,7 @@ impl HNSWIndex { Ok(()) } - fn add_single_item(&mut self, data: &node::Node) -> Result<(), &'static str> { + fn add_single_item(&mut self, data: &N) -> Result<(), &'static str> { //not support asysn if data.len() != self._dimension { return Err("dimension is different"); @@ -599,21 +599,23 @@ impl HNSWIndex { } } -impl ann_index::ANNIndex for HNSWIndex { +impl> ann_index::ANNIndex + for HNSWIndex +{ fn build(&mut self, mt: metrics::Metric) -> Result<(), &'static str> { self.mt = mt; self.batch_construct(mt) } - fn add_node(&mut self, item: &node::Node) -> Result<(), &'static str> { + fn add_node(&mut self, item: &N) -> Result<(), &'static str> { self.add_item_not_constructed(item) } fn built(&self) -> bool { true } - fn node_search_k(&self, item: &node::Node, k: usize) -> Vec<(node::Node, E)> { + fn node_search_k(&self, item: &N, k: usize) -> Vec<(N, E)> { let mut ret: BinaryHeap> = self.search_knn(item, k).unwrap(); - let mut result: Vec<(node::Node, E)> = Vec::with_capacity(k); + let mut result: Vec<(N, E)> = Vec::with_capacity(k); let mut result_idx: Vec<(usize, E)> = Vec::with_capacity(k); while !ret.is_empty() { let top = ret.peek().unwrap(); @@ -641,8 +643,11 @@ impl ann_index::ANNIndex for HNSW } } -impl - ann_index::SerializableIndex for HNSWIndex +impl< + E: node::FloatElement + DeserializeOwned, + T: node::IdxType + DeserializeOwned, + N: node::Node, + > ann_index::SerializableIndex for HNSWIndex { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); diff --git a/src/index/pq_idx.rs b/src/index/pq_idx.rs index 15f0392..a974258 100644 --- a/src/index/pq_idx.rs +++ b/src/index/pq_idx.rs @@ -19,7 +19,7 @@ use std::fs::File; use std::io::Write; #[derive(Default, Debug, Serialize, Deserialize)] -pub struct PQIndex { +pub struct PQIndex> { _dimension: usize, //dimension of data _n_sub: usize, //num of subdata _sub_dimension: usize, //dimension of subdata @@ -37,15 +37,15 @@ pub struct PQIndex { _n_items: usize, _max_item: usize, - _nodes: Vec>>, + _nodes: Vec>, _assigned_center: Vec>, mt: metrics::Metric, //compute metrics // _item2id: HashMap, - _nodes_tmp: Vec>, + _nodes_tmp: Vec, } -impl PQIndex { - pub fn new(dimension: usize, params: &PQParams) -> PQIndex { +impl> PQIndex { + pub fn new(dimension: usize, params: &PQParams) -> PQIndex { let n_sub = params.n_sub; let sub_bits = params.sub_bits; let train_epoch = params.train_epoch; @@ -89,7 +89,7 @@ impl PQIndex { new_pq } - fn init_item(&mut self, data: &node::Node) -> usize { + fn init_item(&mut self, data: &N) -> usize { let cur_id = self._n_items; // self._item2id.insert(item, cur_id); self._nodes.push(Box::new(data.clone())); @@ -97,7 +97,7 @@ impl PQIndex { cur_id } - fn add_item(&mut self, data: &node::Node) -> Result { + fn add_item(&mut self, data: &N) -> Result { if data.len() != self._dimension { return Err("dimension is different"); } @@ -148,13 +148,7 @@ impl PQIndex { self._is_trained = true; } - fn get_distance_from_vec_range( - &self, - x: &node::Node, - y: &[E], - begin: usize, - end: usize, - ) -> E { + fn get_distance_from_vec_range(&self, x: &N, y: &[E], begin: usize, end: usize) -> E { let mut z = x.vectors()[begin..end].to_vec(); if self._has_residual { (0..end - begin).for_each(|i| z[i] -= self._residual[i + begin]); @@ -164,7 +158,7 @@ impl PQIndex { fn search_knn_adc( &self, - search_data: &node::Node, + search_data: &N, k: usize, ) -> Result>, &'static str> { let mut dis2centers: Vec = Vec::new(); @@ -194,13 +188,15 @@ impl PQIndex { } } -impl ann_index::ANNIndex for PQIndex { +impl> ann_index::ANNIndex + for PQIndex +{ fn build(&mut self, _mt: metrics::Metric) -> Result<(), &'static str> { self.mt = _mt; self.train_center(); Result::Ok(()) } - fn add_node(&mut self, item: &node::Node) -> Result<(), &'static str> { + fn add_node(&mut self, item: &N) -> Result<(), &'static str> { match self.add_item(item) { Err(err) => Err(err), _ => Ok(()), @@ -210,7 +206,7 @@ impl ann_index::ANNIndex for PQIn true } - fn node_search_k(&self, item: &node::Node, k: usize) -> Vec<(node::Node, E)> { + fn node_search_k(&self, item: &N, k: usize) -> Vec<(N, E)> { let mut ret: BinaryHeap> = self.search_knn_adc(item, k).unwrap(); let mut result: Vec<(node::Node, E)> = Vec::new(); let mut result_idx: Vec<(usize, E)> = Vec::new(); @@ -240,8 +236,11 @@ impl ann_index::ANNIndex for PQIn } } -impl - ann_index::SerializableIndex for PQIndex +impl< + E: node::FloatElement + DeserializeOwned, + T: node::IdxType + DeserializeOwned, + N: node::Node + DeserializeOwned, + > ann_index::SerializableIndex for PQIndex { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); @@ -265,7 +264,7 @@ impl { +pub struct IVFPQIndex> { _dimension: usize, //dimension of data _n_sub: usize, //num of subdata _sub_dimension: usize, //dimension of subdata @@ -279,20 +278,20 @@ pub struct IVFPQIndex { _n_kmeans_center: usize, _centers: Vec>, _ivflist: Vec>, //ivf center id - _pq_list: Vec>, + _pq_list: Vec>, _is_trained: bool, _n_items: usize, _max_item: usize, - _nodes: Vec>>, + _nodes: Vec>, _assigned_center: Vec>, mt: metrics::Metric, //compute metrics // _item2id: HashMap, - _nodes_tmp: Vec>, + _nodes_tmp: Vec, } -impl IVFPQIndex { - pub fn new(dimension: usize, params: &IVFPQParams) -> IVFPQIndex { +impl> IVFPQIndex { + pub fn new(dimension: usize, params: &IVFPQParams) -> IVFPQIndex { let n_sub = params.n_sub; let sub_bits = params.sub_bits; let n_kmeans_center = params.n_kmeans_center; @@ -329,7 +328,7 @@ impl IVFPQIndex { } } - fn init_item(&mut self, data: &node::Node) -> usize { + fn init_item(&mut self, data: &N) -> usize { let cur_id = self._n_items; // self._item2id.insert(item, cur_id); self._nodes.push(Box::new(data.clone())); @@ -337,7 +336,7 @@ impl IVFPQIndex { cur_id } - fn add_item(&mut self, data: &node::Node) -> Result { + fn add_item(&mut self, data: &N) -> Result { if data.len() != self._dimension { return Err("dimension is different"); } @@ -395,19 +394,13 @@ impl IVFPQIndex { self._is_trained = true; } - fn get_distance_from_vec_range( - &self, - x: &node::Node, - y: &[E], - begin: usize, - end: usize, - ) -> E { + fn get_distance_from_vec_range(&self, x: &N, y: &[E], begin: usize, end: usize) -> E { return metrics::metric(&x.vectors()[begin..end], y, self.mt).unwrap(); } fn search_knn_adc( &self, - search_data: &node::Node, + search_data: &N, k: usize, ) -> Result>, &'static str> { let mut top_centers: BinaryHeap> = BinaryHeap::new(); @@ -439,13 +432,15 @@ impl IVFPQIndex { } } -impl ann_index::ANNIndex for IVFPQIndex { +impl> ann_index::ANNIndex + for IVFPQIndex +{ fn build(&mut self, _mt: metrics::Metric) -> Result<(), &'static str> { self.mt = _mt; self.train(); Result::Ok(()) } - fn add_node(&mut self, item: &node::Node) -> Result<(), &'static str> { + fn add_node(&mut self, item: &N) -> Result<(), &'static str> { match self.add_item(item) { Err(err) => Err(err), _ => Ok(()), @@ -455,7 +450,7 @@ impl ann_index::ANNIndex for IVFP true } - fn node_search_k(&self, item: &node::Node, k: usize) -> Vec<(node::Node, E)> { + fn node_search_k(&self, item: &N, k: usize) -> Vec<(N, E)> { let mut ret: BinaryHeap> = self.search_knn_adc(item, k).unwrap(); let mut result: Vec<(node::Node, E)> = Vec::new(); let mut result_idx: Vec<(usize, E)> = Vec::new(); @@ -485,8 +480,11 @@ impl ann_index::ANNIndex for IVFP } } -impl - ann_index::SerializableIndex for IVFPQIndex +impl< + E: node::FloatElement + DeserializeOwned, + T: node::IdxType + DeserializeOwned, + N: node::Node + DeserializeOwned, + > ann_index::SerializableIndex for IVFPQIndex { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); diff --git a/src/index/ssg_idx.rs b/src/index/ssg_idx.rs index b924151..d8d47fd 100644 --- a/src/index/ssg_idx.rs +++ b/src/index/ssg_idx.rs @@ -23,10 +23,10 @@ use std::io::Write; use std::sync::{Arc, Mutex}; #[derive(Debug, Serialize, Deserialize)] -pub struct SSGIndex { +pub struct SSGIndex> { #[serde(skip_serializing, skip_deserializing)] - nodes: Vec>>, - tmp_nodes: Vec>, // only use for serialization scene + nodes: Vec>, + tmp_nodes: Vec, // only use for serialization scene mt: metrics::Metric, dimension: usize, neighbor_neighbor_size: usize, @@ -44,8 +44,8 @@ pub struct SSGIndex { search_times: usize, } -impl SSGIndex { - pub fn new(dimension: usize, params: &SSGParams) -> SSGIndex { +impl> SSGIndex { + pub fn new(dimension: usize, params: &SSGParams) -> SSGIndex { SSGIndex:: { nodes: Vec::new(), tmp_nodes: Vec::new(), @@ -399,7 +399,7 @@ impl SSGIndex { // avg /= 1.0 * self.nodes.len() as f32; } - fn search(&self, query: &node::Node, k: usize) -> Vec<(node::Node, E)> { + fn search(&self, query: &N, k: usize) -> Vec<(N, E)> { // let mut search_flags = HashSet::with_capacity(self.nodes.len()); let mut search_flags = FixedBitSet::with_capacity(self.nodes.len()); let mut heap: BinaryHeap> = BinaryHeap::new(); // max-heap @@ -480,8 +480,11 @@ impl SSGIndex { } } -impl - ann_index::SerializableIndex for SSGIndex +impl< + E: node::FloatElement + DeserializeOwned, + T: node::IdxType + DeserializeOwned, + N: node::Node + DeserializeOwned, + > ann_index::SerializableIndex for SSGIndex { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); @@ -504,21 +507,23 @@ impl ann_index::ANNIndex for SSGIndex { +impl> ann_index::ANNIndex + for SSGIndex +{ fn build(&mut self, mt: metrics::Metric) -> Result<(), &'static str> { self.mt = mt; self._build(); Result::Ok(()) } - fn add_node(&mut self, item: &node::Node) -> Result<(), &'static str> { + fn add_node(&mut self, item: &N) -> Result<(), &'static str> { self.nodes.push(Box::new(item.clone())); Result::Ok(()) } fn built(&self) -> bool { true } - fn node_search_k(&self, item: &node::Node, k: usize) -> Vec<(node::Node, E)> { + fn node_search_k(&self, item: &N, k: usize) -> Vec<(N, E)> { self.search(item, k) } diff --git a/src/lib.rs b/src/lib.rs index 54dba73..342ae23 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ mod tests { use super::*; use crate::core::ann_index::ANNIndex; + use crate::core::node::MemoryNode; use rand::distributions::Standard; use rand::Rng; use std::collections::HashSet; @@ -112,13 +113,11 @@ mod tests { } } - fn make_idx_baseline< + fn make_idx_baseline(embs: Vec>, idx: &mut Box) + where E: core::node::FloatElement, - T: core::ann_index::ANNIndex + ?Sized, - >( - embs: Vec>, - idx: &mut Box, - ) { + T: core::ann_index::ANNIndex> + ?Sized, + { for i in 0..embs.len() { idx.add_node(&core::node::Node::::new_with_idx(&embs[i], i)) .unwrap(); From 303bbf085001bfb191f57206b71464e4a61b5b4e Mon Sep 17 00:00:00 2001 From: Gavin Mendel-Gleason Date: Sat, 27 May 2023 13:12:57 +0200 Subject: [PATCH 2/4] Experiment with associated tyeps --- src/core/ann_index.rs | 6 +++--- src/core/node.rs | 26 ++++++++++++++------------ src/index/bruteforce_idx.rs | 8 ++++---- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/core/ann_index.rs b/src/core/ann_index.rs index 8dec063..3844948 100644 --- a/src/core/ann_index.rs +++ b/src/core/ann_index.rs @@ -22,7 +22,7 @@ use serde::de::DeserializeOwned; /// ``` /// -pub trait ANNIndex>: +pub trait ANNIndex>: Send + Sync { /// build up the ANN index @@ -35,7 +35,7 @@ pub trait ANNIndex> /// /// it will allocate a space in the heap(Vector), and init a `Node` /// return `Err(&'static str)` if there is something wrong with the adding process, and the `static str` is the debug reason - fn add_node(&mut self, item: &impl node::Node) -> Result<(), &'static str>; + fn add_node(&mut self, item: &impl node::Node) -> Result<(), &'static str>; /// add node /// @@ -143,7 +143,7 @@ pub trait ANNIndex> pub trait SerializableIndex< E: node::FloatElement + DeserializeOwned, T: node::IdxType + DeserializeOwned, - N: node::Node, + N: node::Node, >: Send + Sync + ANNIndex { /// load file with path diff --git a/src/core/node.rs b/src/core/node.rs index e2fa8f7..6771623 100644 --- a/src/core/node.rs +++ b/src/core/node.rs @@ -86,18 +86,20 @@ to_idx_type!(u32); to_idx_type!(u64); to_idx_type!(u128); -pub trait Node: Send + Sync { - fn new(vectors: &[E]) -> Self; - fn new_with_index(vectors: &[E], id: T) -> Self; - fn metric(&self, other: impl Node, t: metrics::Metric) -> Result; - fn vectors(&self) -> Vec; - fn mut_vectors(&mut self) -> &mut Vec; - fn set_vectors(&mut self, v: &[E]); +pub trait Node: Send + Sync { + type E; + type T; + fn new(vectors: &[Self::E]) -> Self; + fn new_with_idx(vectors: &[Self::E], id: Self::T) -> Self; + fn metric(&self, other: impl Node, t: metrics::Metric) -> Result; + fn vectors(&self) -> Vec; + fn mut_vectors(&mut self) -> &mut Vec; + fn set_vectors(&mut self, v: &[Self::E]); fn len(&self) -> usize; fn is_empty(&self) -> bool; - fn idx(&self) -> &Option; - fn set_idx(&mut self, id: T); - fn valid_elements(vectors: &[E]) -> bool; + fn idx(&self) -> &Option; + fn set_idx(&mut self, id: Self::T); + fn valid_elements(vectors: &[Self::E]) -> bool; } /// Node is the main container for the point in the space @@ -110,7 +112,7 @@ pub struct MemoryNode { idx: Option, // data id, it can be any type; } -impl Node for MemoryNode { +impl Node for MemoryNode { /// new without idx /// /// new a point without a idx @@ -132,7 +134,7 @@ impl Node for MemoryNode { } /// calculate the point distance - fn metric(&self, other: &impl Node, t: metrics::Metric) -> Result { + fn metric(&self, other: &impl Node, t: metrics::Metric) -> Result { metrics::metric(&self.vectors, &other.vectors(), t) } diff --git a/src/index/bruteforce_idx.rs b/src/index/bruteforce_idx.rs index 3d563f7..87ae9b1 100644 --- a/src/index/bruteforce_idx.rs +++ b/src/index/bruteforce_idx.rs @@ -13,7 +13,7 @@ use std::fs::File; use std::io::Write; #[derive(Debug, Serialize, Deserialize)] -pub struct BruteForceIndex> { +pub struct BruteForceIndex { #[serde(skip_serializing, skip_deserializing)] nodes: Vec>, tmp_nodes: Vec, // only use for serialization scene @@ -21,9 +21,9 @@ pub struct BruteForceIndex> BruteForceIndex { - pub fn new(dimension: usize, _params: &BruteForceParams) -> BruteForceIndex { - BruteForceIndex:: { +impl> BruteForceIndex { + pub fn new(dimension: usize, _params: &BruteForceParams) -> BruteForceIndex { + BruteForceIndex:: { nodes: Vec::new(), mt: metrics::Metric::Unknown, tmp_nodes: Vec::new(), From 9101bdb974f13e6ca570615f5c0b1ae99c482faa Mon Sep 17 00:00:00 2001 From: Gavin Mendel-Gleason Date: Sat, 27 May 2023 14:19:25 +0200 Subject: [PATCH 3/4] More fixes to associated type approach --- src/core/ann_index.rs | 2 +- src/core/kmeans.rs | 2 +- src/core/node.rs | 12 +++++++++--- src/index/bruteforce_idx.rs | 10 +++++----- src/index/hnsw_idx.rs | 14 +++++++------- src/index/pq_idx.rs | 26 +++++++++++++------------- src/index/ssg_idx.rs | 16 ++++++++-------- src/lib.rs | 6 +++--- 8 files changed, 47 insertions(+), 41 deletions(-) diff --git a/src/core/ann_index.rs b/src/core/ann_index.rs index 3844948..fc68dc7 100644 --- a/src/core/ann_index.rs +++ b/src/core/ann_index.rs @@ -35,7 +35,7 @@ pub trait ANNIndex) -> Result<(), &'static str>; + fn add_node(&mut self, item: &N) -> Result<(), &'static str>; /// add node /// diff --git a/src/core/kmeans.rs b/src/core/kmeans.rs index 4b170b0..9731c5d 100644 --- a/src/core/kmeans.rs +++ b/src/core/kmeans.rs @@ -213,7 +213,7 @@ impl Kmeans { } } -pub fn general_kmeans>( +pub fn general_kmeans( k: usize, epoch: usize, nodes: &[Box], diff --git a/src/core/node.rs b/src/core/node.rs index 6771623..8f30b01 100644 --- a/src/core/node.rs +++ b/src/core/node.rs @@ -86,13 +86,17 @@ to_idx_type!(u32); to_idx_type!(u64); to_idx_type!(u128); -pub trait Node: Send + Sync { +pub trait Node: Clone + Send + Sync { type E; type T; fn new(vectors: &[Self::E]) -> Self; fn new_with_idx(vectors: &[Self::E], id: Self::T) -> Self; - fn metric(&self, other: impl Node, t: metrics::Metric) -> Result; - fn vectors(&self) -> Vec; + fn metric( + &self, + other: &impl Node, + t: metrics::Metric, + ) -> Result; + fn vectors(&self) -> &Vec; fn mut_vectors(&mut self) -> &mut Vec; fn set_vectors(&mut self, v: &[Self::E]); fn len(&self) -> usize; @@ -113,6 +117,8 @@ pub struct MemoryNode { } impl Node for MemoryNode { + type E = E; + type T = T; /// new without idx /// /// new a point without a idx diff --git a/src/index/bruteforce_idx.rs b/src/index/bruteforce_idx.rs index 87ae9b1..2b2ebc8 100644 --- a/src/index/bruteforce_idx.rs +++ b/src/index/bruteforce_idx.rs @@ -21,7 +21,7 @@ pub struct BruteForceIndex { dimension: usize, } -impl> BruteForceIndex { +impl> BruteForceIndex { pub fn new(dimension: usize, _params: &BruteForceParams) -> BruteForceIndex { BruteForceIndex:: { nodes: Vec::new(), @@ -32,8 +32,8 @@ impl> BruteForceInd } } -impl> ann_index::ANNIndex - for BruteForceIndex +impl> + ann_index::ANNIndex for BruteForceIndex { fn build(&mut self, mt: metrics::Metric) -> Result<(), &'static str> { self.mt = mt; @@ -86,8 +86,8 @@ impl> ann_index::AN impl< E: node::FloatElement + DeserializeOwned, T: node::IdxType + DeserializeOwned, - N: node::Node + DeserializeOwned, - > ann_index::SerializableIndex for BruteForceIndex + N: node::Node + DeserializeOwned, + > ann_index::SerializableIndex for BruteForceIndex { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); diff --git a/src/index/hnsw_idx.rs b/src/index/hnsw_idx.rs index 04f6d4b..86819c3 100644 --- a/src/index/hnsw_idx.rs +++ b/src/index/hnsw_idx.rs @@ -21,7 +21,7 @@ use std::io::Write; use std::sync::RwLock; #[derive(Default, Debug, Serialize, Deserialize)] -pub struct HNSWIndex> { +pub struct HNSWIndex> { _dimension: usize, // dimension _n_items: usize, // next item count _n_constructed_items: usize, @@ -55,8 +55,8 @@ pub struct HNSWIndex, } -impl> HNSWIndex { - pub fn new(dimension: usize, params: &HNSWParams) -> HNSWIndex { +impl> HNSWIndex { + pub fn new(dimension: usize, params: &HNSWParams) -> HNSWIndex { HNSWIndex { _dimension: dimension, _n_items: 0, @@ -599,8 +599,8 @@ impl> HNSWIndex> ann_index::ANNIndex - for HNSWIndex +impl> + ann_index::ANNIndex for HNSWIndex { fn build(&mut self, mt: metrics::Metric) -> Result<(), &'static str> { self.mt = mt; @@ -646,8 +646,8 @@ impl> ann_index::AN impl< E: node::FloatElement + DeserializeOwned, T: node::IdxType + DeserializeOwned, - N: node::Node, - > ann_index::SerializableIndex for HNSWIndex + N: node::Node, + > ann_index::SerializableIndex for HNSWIndex { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); diff --git a/src/index/pq_idx.rs b/src/index/pq_idx.rs index a974258..a4494ad 100644 --- a/src/index/pq_idx.rs +++ b/src/index/pq_idx.rs @@ -19,7 +19,7 @@ use std::fs::File; use std::io::Write; #[derive(Default, Debug, Serialize, Deserialize)] -pub struct PQIndex> { +pub struct PQIndex> { _dimension: usize, //dimension of data _n_sub: usize, //num of subdata _sub_dimension: usize, //dimension of subdata @@ -44,8 +44,8 @@ pub struct PQIndex> _nodes_tmp: Vec, } -impl> PQIndex { - pub fn new(dimension: usize, params: &PQParams) -> PQIndex { +impl> PQIndex { + pub fn new(dimension: usize, params: &PQParams) -> PQIndex { let n_sub = params.n_sub; let sub_bits = params.sub_bits; let train_epoch = params.train_epoch; @@ -188,8 +188,8 @@ impl> PQIndex> ann_index::ANNIndex - for PQIndex +impl> + ann_index::ANNIndex for PQIndex { fn build(&mut self, _mt: metrics::Metric) -> Result<(), &'static str> { self.mt = _mt; @@ -239,8 +239,8 @@ impl> ann_index::AN impl< E: node::FloatElement + DeserializeOwned, T: node::IdxType + DeserializeOwned, - N: node::Node + DeserializeOwned, - > ann_index::SerializableIndex for PQIndex + N: node::Node + DeserializeOwned, + > ann_index::SerializableIndex for PQIndex { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); @@ -264,7 +264,7 @@ impl< } #[derive(Default, Debug, Serialize, Deserialize)] -pub struct IVFPQIndex> { +pub struct IVFPQIndex> { _dimension: usize, //dimension of data _n_sub: usize, //num of subdata _sub_dimension: usize, //dimension of subdata @@ -278,7 +278,7 @@ pub struct IVFPQIndex>, _ivflist: Vec>, //ivf center id - _pq_list: Vec>, + _pq_list: Vec>, _is_trained: bool, _n_items: usize, @@ -290,7 +290,7 @@ pub struct IVFPQIndex, } -impl> IVFPQIndex { +impl> IVFPQIndex { pub fn new(dimension: usize, params: &IVFPQParams) -> IVFPQIndex { let n_sub = params.n_sub; let sub_bits = params.sub_bits; @@ -432,8 +432,8 @@ impl> IVFPQIndex> ann_index::ANNIndex - for IVFPQIndex +impl> + ann_index::ANNIndex for IVFPQIndex { fn build(&mut self, _mt: metrics::Metric) -> Result<(), &'static str> { self.mt = _mt; @@ -483,7 +483,7 @@ impl> ann_index::AN impl< E: node::FloatElement + DeserializeOwned, T: node::IdxType + DeserializeOwned, - N: node::Node + DeserializeOwned, + N: node::Node + DeserializeOwned, > ann_index::SerializableIndex for IVFPQIndex { fn load(path: &str) -> Result { diff --git a/src/index/ssg_idx.rs b/src/index/ssg_idx.rs index d8d47fd..544061c 100644 --- a/src/index/ssg_idx.rs +++ b/src/index/ssg_idx.rs @@ -23,7 +23,7 @@ use std::io::Write; use std::sync::{Arc, Mutex}; #[derive(Debug, Serialize, Deserialize)] -pub struct SSGIndex> { +pub struct SSGIndex> { #[serde(skip_serializing, skip_deserializing)] nodes: Vec>, tmp_nodes: Vec, // only use for serialization scene @@ -44,8 +44,8 @@ pub struct SSGIndex search_times: usize, } -impl> SSGIndex { - pub fn new(dimension: usize, params: &SSGParams) -> SSGIndex { +impl> SSGIndex { + pub fn new(dimension: usize, params: &SSGParams) -> SSGIndex { SSGIndex:: { nodes: Vec::new(), tmp_nodes: Vec::new(), @@ -483,12 +483,12 @@ impl> SSGIndex + DeserializeOwned, - > ann_index::SerializableIndex for SSGIndex + N: node::Node + DeserializeOwned, + > ann_index::SerializableIndex for SSGIndex { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); - let mut instance: SSGIndex = bincode::deserialize_from(&file).unwrap(); + let mut instance: SSGIndex = bincode::deserialize_from(&file).unwrap(); instance.nodes = instance .tmp_nodes .iter() @@ -507,8 +507,8 @@ impl< } } -impl> ann_index::ANNIndex - for SSGIndex +impl> + ann_index::ANNIndex for SSGIndex { fn build(&mut self, mt: metrics::Metric) -> Result<(), &'static str> { self.mt = mt; diff --git a/src/lib.rs b/src/lib.rs index 342ae23..90a7798 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -113,14 +113,14 @@ mod tests { } } - fn make_idx_baseline(embs: Vec>, idx: &mut Box) + fn make_idx_baseline(embs: Vec>, idx: &mut Box) where E: core::node::FloatElement, T: core::ann_index::ANNIndex> + ?Sized, + N: core::node::Node, { for i in 0..embs.len() { - idx.add_node(&core::node::Node::::new_with_idx(&embs[i], i)) - .unwrap(); + idx.add_node(&N::new_with_idx(&embs[i], i)).unwrap(); } idx.build(core::metrics::Metric::Euclidean).unwrap(); } From ef774bfad37249ef9c5dfb886d2da893136127ec Mon Sep 17 00:00:00 2001 From: Gavin Mendel-Gleason Date: Sat, 27 May 2023 17:47:58 +0200 Subject: [PATCH 4/4] Make all tests generic --- examples/src/ann_bench.rs | 36 +++++++++++++++++++--------------- examples/src/demo.rs | 3 ++- src/core/kmeans.rs | 8 ++++---- src/core/node.rs | 16 +++++++++------ src/index/bruteforce_idx.rs | 4 ++-- src/index/hnsw_idx.rs | 4 ++-- src/index/pq_idx.rs | 12 ++++++------ src/index/ssg_idx.rs | 12 ++++++------ src/lib.rs | 39 ++++++++++++++++++++++--------------- 9 files changed, 75 insertions(+), 59 deletions(-) diff --git a/examples/src/ann_bench.rs b/examples/src/ann_bench.rs index 5b5028f..8eb313d 100644 --- a/examples/src/ann_bench.rs +++ b/examples/src/ann_bench.rs @@ -1,6 +1,7 @@ #![deny(clippy::all)] use hora::core; use hora::core::ann_index::ANNIndex; +use hora::core::node::MemoryNode; use std::collections::HashSet; use std::time::SystemTime; @@ -79,9 +80,9 @@ fn bench_ssg( let mut metrics_stats: Vec = Vec::new(); for params in params_set.iter() { println!("start params {:?}", params); - let mut ssg_idx = Box::new(hora::index::ssg_idx::SSGIndex::::new( - dimension, params, - )); + let mut ssg_idx = Box::new( + hora::index::ssg_idx::SSGIndex::>::new(dimension, params), + ); make_idx_baseline(train, &mut ssg_idx); metrics_stats.push(bench_calc(ssg_idx, test, neighbors)); println!("finish params {:?}", params); @@ -130,9 +131,10 @@ fn bench_hnsw( let mut metrics_stats: Vec = Vec::new(); for params in params_set.iter() { - let mut hnsw_idx = Box::new(hora::index::hnsw_idx::HNSWIndex::::new( - dimension, params, - )); + let mut hnsw_idx = Box::new(hora::index::hnsw_idx::HNSWIndex::< + usize, + MemoryNode, + >::new(dimension, params)); make_idx_baseline(train, &mut hnsw_idx); metrics_stats.push(bench_calc(hnsw_idx, test, neighbors)); println!("finish params {:?}", params); @@ -164,9 +166,11 @@ fn bench_ivfpq( let mut metrics_stats: Vec = Vec::new(); for params in params_set.iter() { - let mut ivfpq_idx = Box::new(hora::index::pq_idx::IVFPQIndex::::new( - dimension, params, - )); + let mut ivfpq_idx = Box::new(hora::index::pq_idx::IVFPQIndex::< + E, + usize, + MemoryNode, + >::new(dimension, params)); make_idx_baseline(train, &mut ivfpq_idx); metrics_stats.push(bench_calc(ivfpq_idx, test, neighbors)); println!("finish params {:?}", params); @@ -184,7 +188,7 @@ fn bench_ivfpq( } } -fn bench_calc + ?Sized>( +fn bench_calc> + ?Sized>( ann_idx: Box, test: &Vec>, neighbors: &Vec>, @@ -223,17 +227,17 @@ fn bench_calc + ?Sized>( } } -fn make_idx_baseline + ?Sized>( +fn make_idx_baseline< + E: core::node::FloatElement, + T: ANNIndex> + ?Sized, +>( embs: &Vec>, idx: &mut Box, ) { let start = SystemTime::now(); for i in 0..embs.len() { - idx.add_node(&core::node::Node::::new_with_idx( - embs[i].as_slice(), - i, - )) - .unwrap(); + idx.add_node(&core::node::Node::new_with_idx(embs[i].as_slice(), i)) + .unwrap(); } idx.build(core::metrics::Metric::DotProduct).unwrap(); let since_start = SystemTime::now() diff --git a/examples/src/demo.rs b/examples/src/demo.rs index 1cac295..19cf88c 100644 --- a/examples/src/demo.rs +++ b/examples/src/demo.rs @@ -1,4 +1,5 @@ use hora::core::ann_index::ANNIndex; +use hora::core::node::MemoryNode; use rand::{thread_rng, Rng}; use rand_distr::{Distribution, Normal}; @@ -18,7 +19,7 @@ pub fn demo() { } // init index - let mut index = hora::index::hnsw_idx::HNSWIndex::::new( + let mut index = hora::index::hnsw_idx::HNSWIndex::>::new( dimension, &hora::index::hnsw_params::HNSWParams::::default(), ); diff --git a/src/core/kmeans.rs b/src/core/kmeans.rs index 9731c5d..2d19fd7 100644 --- a/src/core/kmeans.rs +++ b/src/core/kmeans.rs @@ -213,7 +213,7 @@ impl Kmeans { } } -pub fn general_kmeans( +pub fn general_kmeans>( k: usize, epoch: usize, nodes: &[Box], @@ -239,7 +239,7 @@ pub fn general_kmeans( let mut idx = 0; let mut distance = E::max_value(); for (i, _item) in means.iter().enumerate() { - let _distance = node.metric(&means[i], mt).unwrap(); + let _distance = node.metric(&**means[i], mt).unwrap(); if _distance < distance { idx = i; distance = _distance; @@ -277,7 +277,7 @@ pub fn general_kmeans( let mut mean_idx = 0; let mut mean_distance = E::max_value(); nodes.iter().zip(0..nodes.len()).for_each(|(node, i)| { - let distance = node.metric(mean, mt).unwrap(); + let distance = node.metric(&***mean, mt).unwrap(); if distance < mean_distance { mean_idx = i; mean_distance = distance; @@ -347,7 +347,7 @@ mod tests { .map(|x| x.iter().map(|p| *p as f32).collect()) .collect(); - let nodes: Vec>> = ns + let nodes: Vec>> = ns .iter() .zip(0..ns.len()) .map(|(vs, idx)| Box::new(node::Node::new_with_idx(vs, idx))) diff --git a/src/core/node.rs b/src/core/node.rs index 8f30b01..a967dc9 100644 --- a/src/core/node.rs +++ b/src/core/node.rs @@ -86,8 +86,8 @@ to_idx_type!(u32); to_idx_type!(u64); to_idx_type!(u128); -pub trait Node: Clone + Send + Sync { - type E; +pub trait Node: Clone + Send + Sync + Serialize + Default { + type E: PartialOrd; type T; fn new(vectors: &[Self::E]) -> Self; fn new_with_idx(vectors: &[Self::E], id: Self::T) -> Self; @@ -140,8 +140,12 @@ impl Node for MemoryNode { } /// calculate the point distance - fn metric(&self, other: &impl Node, t: metrics::Metric) -> Result { - metrics::metric(&self.vectors, &other.vectors(), t) + fn metric( + &self, + other: &impl Node, + t: metrics::Metric, + ) -> Result { + metrics::metric(&self.vectors, &*other.vectors(), t) } // return internal embeddings @@ -202,7 +206,7 @@ fn node_test() { // f64 let v = vec![1.0, 1.0]; let v2 = vec![2.0, 2.0]; - let n = Node::::new(&v); - let n2 = Node::::new(&v2); + let n: MemoryNode = Node::new(&v); + let n2: MemoryNode = Node::new(&v2); assert_eq!(n.metric(&n2, metrics::Metric::Manhattan).unwrap(), 2.0); } diff --git a/src/index/bruteforce_idx.rs b/src/index/bruteforce_idx.rs index 2b2ebc8..b5c66c5 100644 --- a/src/index/bruteforce_idx.rs +++ b/src/index/bruteforce_idx.rs @@ -55,7 +55,7 @@ impl> heap.push(neighbor::Neighbor::new( // use max heap, and every time pop out the greatest one in the heap i, - item.metric(node, self.mt).unwrap(), + item.metric(&**node, self.mt).unwrap(), )); if heap.len() > k { let _xp = heap.pop().unwrap(); @@ -91,7 +91,7 @@ impl< { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); - let mut instance: BruteForceIndex = bincode::deserialize_from(file).unwrap(); + let mut instance: BruteForceIndex = bincode::deserialize_from(file).unwrap(); instance.nodes = instance .tmp_nodes .iter() diff --git a/src/index/hnsw_idx.rs b/src/index/hnsw_idx.rs index 86819c3..6856714 100644 --- a/src/index/hnsw_idx.rs +++ b/src/index/hnsw_idx.rs @@ -646,12 +646,12 @@ impl> impl< E: node::FloatElement + DeserializeOwned, T: node::IdxType + DeserializeOwned, - N: node::Node, + N: node::Node + DeserializeOwned, > ann_index::SerializableIndex for HNSWIndex { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); - let mut instance: HNSWIndex = bincode::deserialize_from(&file).unwrap(); + let mut instance: HNSWIndex = bincode::deserialize_from(&file).unwrap(); instance._nodes = instance ._nodes_tmp .iter() diff --git a/src/index/pq_idx.rs b/src/index/pq_idx.rs index a4494ad..bf18e48 100644 --- a/src/index/pq_idx.rs +++ b/src/index/pq_idx.rs @@ -55,7 +55,7 @@ impl> PQInd assert!(sub_bits <= 32); let n_center_per_sub = (1 << sub_bits) as usize; let code_bytes = sub_bytes * n_sub; - let mut new_pq = PQIndex:: { + let mut new_pq = PQIndex:: { _dimension: dimension, _n_sub: n_sub, _sub_dimension: sub_dimension, @@ -208,7 +208,7 @@ impl> fn node_search_k(&self, item: &N, k: usize) -> Vec<(N, E)> { let mut ret: BinaryHeap> = self.search_knn_adc(item, k).unwrap(); - let mut result: Vec<(node::Node, E)> = Vec::new(); + let mut result: Vec<(N, E)> = Vec::new(); let mut result_idx: Vec<(usize, E)> = Vec::new(); while !ret.is_empty() { let top = ret.peek().unwrap(); @@ -244,7 +244,7 @@ impl< { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); - let mut instance: PQIndex = bincode::deserialize_from(&file).unwrap(); + let mut instance: PQIndex = bincode::deserialize_from(&file).unwrap(); instance._nodes = instance ._nodes_tmp .iter() @@ -373,7 +373,7 @@ impl> IVFPQ self._ivflist[center_id].push(i); }); for i in 0..n_center { - let mut center_pq = PQIndex::::new( + let mut center_pq = PQIndex::::new( self._dimension, &PQParams::default() .n_sub(self._n_sub) @@ -452,7 +452,7 @@ impl> fn node_search_k(&self, item: &N, k: usize) -> Vec<(N, E)> { let mut ret: BinaryHeap> = self.search_knn_adc(item, k).unwrap(); - let mut result: Vec<(node::Node, E)> = Vec::new(); + let mut result: Vec<(N, E)> = Vec::new(); let mut result_idx: Vec<(usize, E)> = Vec::new(); while !ret.is_empty() { let top = ret.peek().unwrap(); @@ -488,7 +488,7 @@ impl< { fn load(path: &str) -> Result { let file = File::open(path).unwrap_or_else(|_| panic!("unable to open file {:?}", path)); - let mut instance: IVFPQIndex = bincode::deserialize_from(&file).unwrap(); + let mut instance: IVFPQIndex = bincode::deserialize_from(&file).unwrap(); instance._nodes = instance ._nodes_tmp .iter() diff --git a/src/index/ssg_idx.rs b/src/index/ssg_idx.rs index 544061c..a91f676 100644 --- a/src/index/ssg_idx.rs +++ b/src/index/ssg_idx.rs @@ -46,7 +46,7 @@ pub struct SSGIndex> { impl> SSGIndex { pub fn new(dimension: usize, params: &SSGParams) -> SSGIndex { - SSGIndex:: { + SSGIndex:: { nodes: Vec::new(), tmp_nodes: Vec::new(), mt: metrics::Metric::Unknown, @@ -80,7 +80,7 @@ impl> SSGIn } heap.push(neighbor::Neighbor::new( i, - item.metric(node, self.mt).unwrap(), + item.metric(&**node, self.mt).unwrap(), )); if heap.len() > self.init_k { heap.pop(); @@ -121,7 +121,7 @@ impl> SSGIn continue; } flags.insert(*nn_id); - let dist = self.nodes[q].metric(&self.nodes[*nn_id], self.mt).unwrap(); + let dist = self.nodes[q].metric(&*self.nodes[*nn_id], self.mt).unwrap(); expand_neighbors_tmp.push(neighbor::Neighbor::new(*nn_id, dist)); if expand_neighbors_tmp.len() >= self.neighbor_neighbor_size { return; @@ -217,7 +217,7 @@ impl> SSGIn expand_neighbors_tmp.push(neighbor::Neighbor::new( *linked_id, self.nodes[query_id] - .metric(&self.nodes[*linked_id], self.mt) + .metric(&*self.nodes[*linked_id], self.mt) .unwrap(), )); }); @@ -241,7 +241,7 @@ impl> SSGIn break; } let djk = self.nodes[iter.idx()] - .metric(&self.nodes[p.idx()], self.mt) + .metric(&*self.nodes[p.idx()], self.mt) .unwrap(); let cos_ij = (p.distance().powi(2) + iter.distance().powi(2) - djk.powi(2)) / (E::from_usize(2).unwrap() * (p.distance() * iter.distance())); @@ -321,7 +321,7 @@ impl> SSGIn break; } let djk = self.nodes[rt.idx()] - .metric(&self.nodes[p.idx()], self.mt) + .metric(&*self.nodes[p.idx()], self.mt) .unwrap(); let cos_ij = (p.distance().powi(2) + rt.distance().powi(2) - djk.powi(2)) / (E::from_usize(2).unwrap() * (p.distance() * rt.distance())); diff --git a/src/lib.rs b/src/lib.rs index 90a7798..caf7dc1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,10 +7,10 @@ mod tests { use crate::core::ann_index::ANNIndex; use crate::core::node::MemoryNode; + use crate::core::node::Node; use rand::distributions::Standard; use rand::Rng; use std::collections::HashSet; - use std::sync::Arc; use std::sync::Mutex; fn make_normal_distribution_clustering( @@ -59,29 +59,36 @@ mod tests { let (_, ns) = make_normal_distribution_clustering(node_n, nodes_every_cluster, dimension, 100.0); - let mut bf_idx = Box::new(index::bruteforce_idx::BruteForceIndex::::new( + let mut bf_idx = Box::new(index::bruteforce_idx::BruteForceIndex::< + MemoryNode, + >::new( dimension, &index::bruteforce_params::BruteForceParams::default(), )); // let bpt_idx = Box::new( // index::bpt_idx::BPTIndex::::new(dimension, &index::bpt_params::BPTParams::default()), // ); - let hnsw_idx = Box::new(index::hnsw_idx::HNSWIndex::::new( - dimension, - &index::hnsw_params::HNSWParams::::default(), - )); + let hnsw_idx = Box::new( + index::hnsw_idx::HNSWIndex::>::new( + dimension, + &index::hnsw_params::HNSWParams::::default(), + ), + ); - let pq_idx = Box::new(index::pq_idx::PQIndex::::new( + let pq_idx = Box::new(index::pq_idx::PQIndex::>::new( dimension, &index::pq_params::PQParams::::default(), )); - let ssg_idx = Box::new(index::ssg_idx::SSGIndex::::new( - dimension, - &index::ssg_params::SSGParams::default(), - )); + let ssg_idx = Box::new( + index::ssg_idx::SSGIndex::>::new( + dimension, + &index::ssg_params::SSGParams::default(), + ), + ); - let mut indices: Vec>> = - vec![pq_idx, ssg_idx, hnsw_idx]; + let mut indices: Vec< + Box>>, + > = vec![pq_idx, ssg_idx, hnsw_idx]; let accuracy = Arc::new(Mutex::new(Vec::new())); for i in 0..indices.len() { make_idx_baseline(ns.clone(), &mut indices[i]); @@ -113,11 +120,11 @@ mod tests { } } - fn make_idx_baseline(embs: Vec>, idx: &mut Box) + fn make_idx_baseline(embs: Vec>, idx: &mut Box) where E: core::node::FloatElement, - T: core::ann_index::ANNIndex> + ?Sized, - N: core::node::Node, + N: core::node::Node, + I: core::ann_index::ANNIndex + ?Sized, { for i in 0..embs.len() { idx.add_node(&N::new_with_idx(&embs[i], i)).unwrap();