Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use fast embeddings to automatically generate embeddings #44

Merged
merged 13 commits into from
Dec 26, 2024
64 changes: 64 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
name: CI

on:
push:
branches: [main]
pull_request:
branches: [main]

env:
CARGO_TERM_COLOR: always

jobs:
rust-checks:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install Rust
uses: dtolnay/rust-toolchain@stable
with:
components: clippy, rustfmt

- name: Cache dependencies
uses: actions/cache@v3
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}

- name: Check formatting
run: cargo fmt --all -- --check

- name: Run clippy
run: cargo clippy -- -D warnings

- name: Run tests
run: cargo test

- name: Check semver
uses: obi1kenobi/cargo-semver-checks-action@v2

wasm-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install Node.js
uses: actions/setup-node@v4
with:
node-version: "20"

- name: Install wasm-pack
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh

- name: Build wasm package
run: wasm-pack build --target web

- name: Install and build www
working-directory: www
run: |
npm install
npm run build
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ wasm/target
Cargo.lock
pkg/
node_modules/
victor_test_data/
victor_test_data/
.fastembed_cache/
.aider*
.env
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "victor-db"
version = "0.1.3"
version = "0.2.3"
authors = ["Sam Hall <s@muel.email", "Andre Popovitch <andre@popovit.ch>"]
edition = "2021"
license-file = "LICENSE.md"
Expand Down Expand Up @@ -52,6 +52,7 @@ js-sys = "0.3"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
tokio = { version = "1", features = ["rt", "macros", "fs", "io-util"] }
fastembed = "4.3.0"

[dev-dependencies]
wasm-bindgen-test = "0.3"
Expand Down
22 changes: 8 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ cargo add victor-db

#### Usage

The Rust API can automatically create embeddings for you with [fastembed-rs](https://github.com/anush008/fastembed-rs?tab=readme-ov-file)'s default model (currently [BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5)).

```rust
use std::path::PathBuf;

Expand All @@ -65,30 +67,22 @@ let mut victor = Db::new(PathBuf::from("./victor_test_data"));
victor.clear_db().await.unwrap();

victor
.write(
"Test Vector 1",
vec![1.0, 0.0, 0.0],
vec!["Test".to_string()],
)
.await;
victor
.write(
"Test Vector 2",
vec![0.0, 1.0, 0.0],
vec!["Test".to_string()],
.add_many(
vec!["Pinapple", "Rocks"], // documents
vec!["PizzaToppings"], // tags (only used for filtering)
)
.await;

// read the 10 closest results from victor that are tagged with "tags"
// (only 2 will be returned because we only inserted two embeddings)
let nearest = victor
.find_nearest_neighbors(vec![0.9, 0.0, 0.0], vec!["Test".to_string()], 10)
.search(vec!["Hawaiian pizza".to_string()], 10)
.await
.first()
.unwrap()
.content
.clone();
assert_eq!(nearest, "Test Vector 1".to_string());
assert_eq!(nearest, "Pineapple".to_string());
```

This example is also in the `/examples` directory. If you've cloned this repository, you can run it with `cargo run --example native_filesystem`.
Expand All @@ -100,7 +94,7 @@ This example is also in the `/examples` directory. If you've cloned this reposit
**Install wasm** pack with `cargo install wasm-pack` or `npm i -g wasm-pack`
(https://rustwasm.github.io/wasm-pack/installer/)

2. **Build Victor** with `wasm-pack build`
2. **Build Victor** with `wasm-pack build --target web`

3. **Set up the example project**, which is in `www/`.

Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions examples/native_filesystem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ async fn main() {
victor.clear_db().await.unwrap();

victor
.write(
.add_embedding(
"Test Vector 1",
vec![1.0, 0.0, 0.0],
vec!["Test".to_string()],
)
.await;
victor
.write(
.add_embedding(
"Test Vector 2",
vec![0.0, 1.0, 0.0],
vec!["Test".to_string()],
Expand All @@ -27,7 +27,7 @@ async fn main() {
// read the 10 closest results from victor that are tagged with "tags"
// (only 2 will be returned because we only inserted two embeddings)
let nearest = victor
.find_nearest_neighbors(vec![0.9, 0.0, 0.0], vec!["Test".to_string()], 10)
.search_embedding(vec![0.9, 0.0, 0.0], vec!["Test".to_string()], 10)
.await
.first()
.unwrap()
Expand Down
2 changes: 2 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[toolchain]
channel = "1.82.0"
49 changes: 46 additions & 3 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,32 @@ impl<D: DirectoryHandle> Victor<D> {
Self { root }
}

pub async fn write(&mut self, content: impl Into<String>, vector: Vec<f32>, tags: Vec<String>) {
#[cfg(not(target_arch = "wasm32"))]
pub async fn add_many(&mut self, content: Vec<impl Into<String>>, tags: Vec<String>) {
let model = fastembed::TextEmbedding::try_new(Default::default()).unwrap();
let content = content
.into_iter()
.map(|c| c.into())
.collect::<Vec<String>>();

let vectors = model.embed(content.clone(), None).unwrap();
for (vector, content) in vectors.iter().zip(content.iter()) {
self.add_embedding(content.clone(), vector.clone(), tags.clone())
.await;
}
}

#[cfg(not(target_arch = "wasm32"))]
pub async fn add(&mut self, content: impl Into<String>, tags: Vec<String>) {
self.add_many(vec![content], tags).await;
}

pub async fn add_embedding(
&mut self,
content: impl Into<String>,
vector: Vec<f32>,
tags: Vec<String>,
) {
let content = content.into();

let id = Uuid::new_v4();
Expand All @@ -88,7 +113,25 @@ impl<D: DirectoryHandle> Victor<D> {
self.write_content(content, id).await.unwrap();
}

pub async fn find_nearest_neighbors(
#[cfg(not(target_arch = "wasm32"))]
pub async fn search(
&self,
content: impl Into<String>,
with_tags: Vec<String>,
top_n: u32,
) -> Vec<NearestNeighborsResult> {
let model = fastembed::TextEmbedding::try_new(Default::default()).unwrap();
let content = content.into();
let vector = model
.embed(vec![content.clone()], None)
.unwrap()
.first()
.cloned()
.unwrap();
self.search_embedding(vector, with_tags, top_n).await
}

pub async fn search_embedding(
&self,
mut vector: Vec<f32>,
with_tags: Vec<String>,
Expand Down Expand Up @@ -573,7 +616,7 @@ impl Eq for NearestNeighborsResult {}

impl PartialOrd for NearestNeighborsResult {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.similarity.partial_cmp(&other.similarity)
Some(self.cmp(other))
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/filesystem/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl filesystem::WritableFileStream for WritableFileStream {
.collect::<Vec<u8>>();

self.cursor_pos += data_len;

Ok(())
}

Expand Down
18 changes: 6 additions & 12 deletions src/filesystem/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,10 @@ impl filesystem::DirectoryHandle for DirectoryHandle {
name: &str,
options: &filesystem::GetFileHandleOptions,
) -> Result<Self::FileHandleT, Self::Error> {
let fs_options = FileSystemGetFileOptions::new();
fs_options.set_create(options.create);
let file_system_file_handle = FileSystemFileHandle::from(
JsFuture::from(self.0.get_file_handle_with_options(
name,
FileSystemGetFileOptions::new().create(options.create),
))
.await?,
JsFuture::from(self.0.get_file_handle_with_options(name, &fs_options)).await?,
);
Ok(FileHandle(file_system_file_handle))
}
Expand All @@ -80,14 +78,10 @@ impl filesystem::FileHandle for FileHandle {
&mut self,
options: &filesystem::CreateWritableOptions,
) -> Result<Self::WritableFileStreamT, Self::Error> {
let fs_options = FileSystemCreateWritableOptions::new();
fs_options.set_keep_existing_data(options.keep_existing_data);
let file_system_writable_file_stream = FileSystemWritableFileStream::unchecked_from_js(
JsFuture::from(
self.0.create_writable_with_options(
FileSystemCreateWritableOptions::new()
.keep_existing_data(options.keep_existing_data),
),
)
.await?,
JsFuture::from(self.0.create_writable_with_options(&fs_options)).await?,
);
Ok(WritableFileStream(file_system_writable_file_stream))
}
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl Db {
})
.unwrap_or(vec![]);

self.victor.write(content, embedding, tags).await;
self.victor.add_embedding(content, embedding, tags).await;
}

pub async fn search(
Expand All @@ -114,7 +114,7 @@ impl Db {

let nearest_neighbors = self
.victor
.find_nearest_neighbors(embedding, tags, top_n.unwrap_or(10.0) as u32)
.search_embedding(embedding, tags, top_n.unwrap_or(10.0) as u32)
.await;

serde_wasm_bindgen::to_value(&nearest_neighbors).unwrap()
Expand Down
Loading
Loading