Skip to content

Commit

Permalink
Add batch add api (to rust)
Browse files Browse the repository at this point in the history
  • Loading branch information
anchpop committed Dec 26, 2024
1 parent 70e3a63 commit ab31758
Showing 1 changed file with 85 additions and 39 deletions.
124 changes: 85 additions & 39 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,31 +86,46 @@ impl<D: DirectoryHandle> Victor<D> {
.collect::<Vec<String>>();

let vectors = model.embed(content.clone(), None).unwrap();
for (vector, content) in vectors.iter().zip(content.iter()) {
self.add_embedding(content.clone(), vector.clone(), tags.clone())
.await;
}

let to_add = content.into_iter().zip(vectors.into_iter()).collect();
self.add_embedding_many(to_add, tags).await;
}

/// Adds a single piece of content under the given tags.
///
/// This is a convenience wrapper: the content is wrapped in a one-element
/// batch and handed to `add_many`, which does the actual work.
#[cfg(not(target_arch = "wasm32"))]
pub async fn add(&mut self, content: impl Into<String>, tags: Vec<String>) {
    let single_item_batch = vec![content];
    self.add_many(single_item_batch, tags).await;
}

pub async fn add_embedding_many(
&mut self,
to_add: Vec<(impl Into<String>, Vec<f32>)>,
tags: Vec<String>,
) {
let (contents, embeddings) = to_add
.into_iter()
.map(|(content, embedding)| {
let uuid = Uuid::new_v4();
(
(content.into(), uuid.clone()),
Embedding {
id: uuid,
vector: embedding,
},
)
})
.unzip();

self.write_embeddings(embeddings, tags).await.unwrap();
self.write_contents(contents).await.unwrap();
}

pub async fn add_embedding(
&mut self,
content: impl Into<String>,
vector: Vec<f32>,
tags: Vec<String>,
) {
let content = content.into();

let id = Uuid::new_v4();

let embedding = Embedding { id, vector };

self.write_embedding(embedding, tags).await.unwrap();
self.write_content(content, id).await.unwrap();
self.add_embedding_many(vec![(content, vector)], tags).await;
}

#[cfg(not(target_arch = "wasm32"))]
Expand Down Expand Up @@ -150,7 +165,8 @@ impl<D: DirectoryHandle> Victor<D> {
.is_ok();

if is_projected {
vector = self.project_single_vector(vector).await;
let eigen_file = self.eigen_file().await;
vector = self.project_single_vector(vector, eigen_file);
}

let mut nearest_neighbors = BinaryHeap::with_capacity(top_n);
Expand Down Expand Up @@ -355,14 +371,17 @@ impl<D: DirectoryHandle> Victor<D> {
bincode::deserialize::<u32>(embedding_size_bytes).expect("Failed to deserialize header")
}

async fn project_single_vector(&self, vector: Vec<f32>) -> Vec<f32> {
/// Reads the full contents of `eigen.bin` from the root directory.
///
/// The file handle is requested with `create: true`. The raw bytes are
/// returned as-is; deserialization is left to the caller.
///
/// # Panics
///
/// Panics if the file handle cannot be obtained or the read fails.
async fn eigen_file(&self) -> Vec<u8> {
    let eigen_file_handle = self
        .root
        .get_file_handle_with_options("eigen.bin", &GetFileHandleOptions { create: true })
        .await
        .unwrap();

    // Read exactly once and return the bytes directly; an extra read into an
    // unused binding would only repeat the I/O.
    eigen_file_handle.read().await.unwrap()
}

fn project_single_vector(&self, vector: Vec<f32>, eigen_file: Vec<u8>) -> Vec<f32> {
let vector_projection: VectorProjection = bincode::deserialize(&eigen_file).unwrap();

let centered_vector = vector
Expand All @@ -379,9 +398,9 @@ impl<D: DirectoryHandle> Victor<D> {
projected_vector
}

async fn write_embedding(
async fn write_embeddings(
&mut self,
mut embedding: Embedding,
mut embeddings: Vec<Embedding>,
tags: Vec<String>,
) -> Result<(), D::Error> {
let mut file_handle = Index::get_exact_db_file(&mut self.root, tags).await?;
Expand All @@ -393,11 +412,18 @@ impl<D: DirectoryHandle> Victor<D> {
.is_ok();

if is_projected {
let vector = self.project_single_vector(embedding.vector.clone()).await;
embedding = Embedding {
id: embedding.id,
vector,
};
let eigen_file = self.eigen_file().await;
embeddings = embeddings
.into_iter()
.map(|embedding| {
let vector =
self.project_single_vector(embedding.vector.clone(), eigen_file.clone());
Embedding {
id: embedding.id,
vector,
}
})
.collect();
}

let mut writable = file_handle
Expand All @@ -408,38 +434,56 @@ impl<D: DirectoryHandle> Victor<D> {

writable.seek(file_handle.size().await?).await?;

let embedding_serialized =
bincode::serialize(&embedding).expect("Failed to serialize embedding");
let embeddings_serialized = embeddings
.into_iter()
.map(|embedding| bincode::serialize(&embedding).expect("Failed to serialize embedding"))
.collect::<Vec<_>>();

if file_handle.size().await? == 0 {
let len_as_u32: u32 = embedding_serialized.len() as u32;
// check that the embeddings are all the same size
// and get that size
let embedding_size = match &embeddings_serialized
.iter()
.map(|embedding| embedding.len())
.collect::<HashSet<_>>()
.into_iter()
.collect::<Vec<_>>()[..]
{
[size] => *size as u32,
_ => panic!("All embeddings must be the same size"),
};

if file_handle.size().await? == 0 {
let serialized_size =
bincode::serialize(&len_as_u32).expect("Failed to serialize size");
bincode::serialize(&embedding_size).expect("Failed to serialize size");

writable.write_at_cursor_pos(serialized_size).await?;
} else {
let embedding_size = Self::get_embedding_size(file_handle.read().await?);
if embedding_serialized.len() as u32 != embedding_size {
panic!(
"Embedding size mismatch: expected {} but got {}",
embedding_size,
embedding_serialized.len()
);
}
let previous_embedding_size = Self::get_embedding_size(file_handle.read().await?);
assert_eq!(
embedding_size, previous_embedding_size,
"Embedding size mismatch: expected {} but got {}",
previous_embedding_size, embedding_size
);
}

writable.write_at_cursor_pos(embedding_serialized).await?;
let all_embeddings_serialized = embeddings_serialized
.into_iter()
.flatten()
.collect::<Vec<_>>();
writable
.write_at_cursor_pos(all_embeddings_serialized)
.await?;

writable.close().await?;

if file_handle.size().await? > 1000000 && !is_projected {
if cfg!(target_arch = "wasm32") && file_handle.size().await? > 1000000 && !is_projected {
self.project_embeddings().await;
}

Ok(())
}

async fn write_content(&mut self, content: String, id: Uuid) -> Result<(), D::Error> {
async fn write_contents(&mut self, content: Vec<(String, Uuid)>) -> Result<(), D::Error> {
let mut content_file_handle = self
.root
.get_file_handle_with_options("content.bin", &GetFileHandleOptions { create: true })
Expand All @@ -453,7 +497,9 @@ impl<D: DirectoryHandle> Victor<D> {
bincode::deserialize(&existing_content).expect("Failed to deserialize existing data")
};

hashmap.insert(id, content);
for (content, id) in content {
hashmap.insert(id, content);
}

let updated_data = bincode::serialize(&hashmap).expect("Failed to serialize hashmap");

Expand Down

0 comments on commit ab31758

Please sign in to comment.