Skip to content

Commit

Permalink
Use fast embeddings to automatically generate embeddings (#44)
Browse files Browse the repository at this point in the history
* fast embeddings

* ci: Add GitHub Actions workflow for Rust and WASM builds

* format

* add rust-toolchain.toml

* clippy

* update readme

* fix webpack for newer rust versions

* fix tests

* nvm in ci

* fix: Update CopyWebpackPlugin configuration to latest API syntax

* update packages

* fix: Add syncWebAssembly support to webpack config for wasm loading

* update webpack
  • Loading branch information
anchpop authored Dec 26, 2024
1 parent 85e1fac commit 70e3a63
Show file tree
Hide file tree
Showing 15 changed files with 1,472 additions and 1,758 deletions.
64 changes: 64 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
name: CI

on:
push:
branches: [main]
pull_request:
branches: [main]

env:
CARGO_TERM_COLOR: always

jobs:
rust-checks:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install Rust
uses: dtolnay/rust-toolchain@stable
with:
components: clippy, rustfmt

- name: Cache dependencies
  uses: actions/cache@v4
  with:
    path: |
      ~/.cargo/registry
      ~/.cargo/git
      target
    # Cargo.lock is git-ignored in this repository, so hashing it would yield
    # an empty hash and a cache key that never changes. Key on the manifests
    # and the pinned toolchain instead so dependency or compiler changes
    # produce a fresh cache entry.
    key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml', 'rust-toolchain.toml') }}
    # Fall back to the most recent cache for this OS when there is no exact hit.
    restore-keys: |
      ${{ runner.os }}-cargo-

- name: Check formatting
run: cargo fmt --all -- --check

- name: Run clippy
run: cargo clippy -- -D warnings

- name: Run tests
run: cargo test

- name: Check semver
uses: obi1kenobi/cargo-semver-checks-action@v2

wasm-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install Node.js
uses: actions/setup-node@v4
with:
node-version: "20"

- name: Install wasm-pack
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh

- name: Build wasm package
run: wasm-pack build --target web

- name: Install and build www
working-directory: www
run: |
npm install
npm run build
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ wasm/target
Cargo.lock
pkg/
node_modules/
victor_test_data/
victor_test_data/
.fastembed_cache/
.aider*
.env
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "victor-db"
version = "0.1.3"
version = "0.2.3"
authors = ["Sam Hall <s@muel.email>", "Andre Popovitch <andre@popovit.ch>"]
edition = "2021"
license-file = "LICENSE.md"
Expand Down Expand Up @@ -52,6 +52,7 @@ js-sys = "0.3"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
tokio = { version = "1", features = ["rt", "macros", "fs", "io-util"] }
fastembed = "4.3.0"

[dev-dependencies]
wasm-bindgen-test = "0.3"
Expand Down
22 changes: 8 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ cargo add victor-db

#### Usage

The Rust API can automatically create embeddings for you with [fastembed-rs](https://github.com/anush008/fastembed-rs?tab=readme-ov-file)'s default model (currently [BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5)).

```rust
use std::path::PathBuf;

Expand All @@ -65,30 +67,22 @@ let mut victor = Db::new(PathBuf::from("./victor_test_data"));
victor.clear_db().await.unwrap();

victor
.write(
"Test Vector 1",
vec![1.0, 0.0, 0.0],
vec!["Test".to_string()],
)
.await;
victor
.write(
"Test Vector 2",
vec![0.0, 1.0, 0.0],
vec!["Test".to_string()],
.add_many(
vec!["Pineapple", "Rocks"], // documents
vec!["PizzaToppings".to_string()], // tags (only used for filtering)
)
.await;

// read the 10 closest results from victor that are tagged with "tags"
// (only 2 will be returned because we only inserted two embeddings)
let nearest = victor
.find_nearest_neighbors(vec![0.9, 0.0, 0.0], vec!["Test".to_string()], 10)
.search("Hawaiian pizza", vec!["PizzaToppings".to_string()], 10)
.await
.first()
.unwrap()
.content
.clone();
assert_eq!(nearest, "Test Vector 1".to_string());
assert_eq!(nearest, "Pineapple".to_string());
```

This example is also in the `/examples` directory. If you've cloned this repository, you can run it with `cargo run --example native_filesystem`.
Expand All @@ -100,7 +94,7 @@ This example is also in the `/examples` directory. If you've cloned this reposit
**Install wasm-pack** with `cargo install wasm-pack` or `npm i -g wasm-pack`
(https://rustwasm.github.io/wasm-pack/installer/)

2. **Build Victor** with `wasm-pack build`
2. **Build Victor** with `wasm-pack build --target web`

3. **Set up the example project**, which is in `www/`.

Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions examples/native_filesystem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ async fn main() {
victor.clear_db().await.unwrap();

victor
.write(
.add_embedding(
"Test Vector 1",
vec![1.0, 0.0, 0.0],
vec!["Test".to_string()],
)
.await;
victor
.write(
.add_embedding(
"Test Vector 2",
vec![0.0, 1.0, 0.0],
vec!["Test".to_string()],
Expand All @@ -27,7 +27,7 @@ async fn main() {
// read the 10 closest results from victor that are tagged with "tags"
// (only 2 will be returned because we only inserted two embeddings)
let nearest = victor
.find_nearest_neighbors(vec![0.9, 0.0, 0.0], vec!["Test".to_string()], 10)
.search_embedding(vec![0.9, 0.0, 0.0], vec!["Test".to_string()], 10)
.await
.first()
.unwrap()
Expand Down
2 changes: 2 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[toolchain]
channel = "1.82.0"
49 changes: 46 additions & 3 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,32 @@ impl<D: DirectoryHandle> Victor<D> {
Self { root }
}

pub async fn write(&mut self, content: impl Into<String>, vector: Vec<f32>, tags: Vec<String>) {
/// Embeds every document in `content` with fastembed's default text model and
/// stores each one, paired with a copy of `tags`, through `add_embedding`.
///
/// Compiled only for non-`wasm32` targets. An empty `content` is a no-op and
/// does not initialise the embedding model.
///
/// # Panics
///
/// Panics if the embedding model cannot be initialised or if embedding the
/// documents fails.
#[cfg(not(target_arch = "wasm32"))]
pub async fn add_many(&mut self, content: Vec<impl Into<String>>, tags: Vec<String>) {
    // Nothing to store: skip the model initialisation entirely.
    if content.is_empty() {
        return;
    }

    let content: Vec<String> = content.into_iter().map(Into::into).collect();

    // Scope the model so it is dropped before any storage call is awaited.
    let vectors = {
        let model = fastembed::TextEmbedding::try_new(Default::default())
            .expect("failed to initialise the fastembed text embedding model");
        // Embed borrowed strings so the owned documents can be moved into
        // storage afterwards without cloning the whole batch.
        model
            .embed(content.iter().map(String::as_str).collect::<Vec<&str>>(), None)
            .expect("failed to embed documents")
    };

    // Consume both vectors so neither the documents nor the embeddings are cloned.
    for (document, vector) in content.into_iter().zip(vectors) {
        self.add_embedding(document, vector, tags.clone()).await;
    }
}

/// Embeds and stores a single document.
///
/// Convenience wrapper that passes `content` to `add_many` as a one-element
/// batch, forwarding `tags` unchanged. Compiled only for non-`wasm32` targets.
#[cfg(not(target_arch = "wasm32"))]
pub async fn add(&mut self, content: impl Into<String>, tags: Vec<String>) {
    let single_document = vec![content];
    self.add_many(single_document, tags).await
}

pub async fn add_embedding(
&mut self,
content: impl Into<String>,
vector: Vec<f32>,
tags: Vec<String>,
) {
let content = content.into();

let id = Uuid::new_v4();
Expand All @@ -88,7 +113,25 @@ impl<D: DirectoryHandle> Victor<D> {
self.write_content(content, id).await.unwrap();
}

pub async fn find_nearest_neighbors(
/// Embeds `content` with fastembed's default text model and searches with the
/// resulting vector.
///
/// `with_tags` and `top_n` are forwarded unchanged to `search_embedding`,
/// whose result is returned as-is. Compiled only for non-`wasm32` targets.
///
/// # Panics
///
/// Panics if the embedding model cannot be initialised, if embedding the
/// query fails, or if the model returns no embedding for it.
#[cfg(not(target_arch = "wasm32"))]
pub async fn search(
    &self,
    content: impl Into<String>,
    with_tags: Vec<String>,
    top_n: u32,
) -> Vec<NearestNeighborsResult> {
    let query: String = content.into();

    // Scope the model so it is dropped before the search itself is awaited.
    let vector = {
        let model = fastembed::TextEmbedding::try_new(Default::default())
            .expect("failed to initialise the fastembed text embedding model");
        // Move the query in and take the single embedding out by value; no
        // clone of either the string or the vector is needed.
        model
            .embed(vec![query], None)
            .expect("failed to embed the search query")
            .into_iter()
            .next()
            .expect("fastembed returned no embedding for the search query")
    };

    self.search_embedding(vector, with_tags, top_n).await
}

pub async fn search_embedding(
&self,
mut vector: Vec<f32>,
with_tags: Vec<String>,
Expand Down Expand Up @@ -573,7 +616,7 @@ impl Eq for NearestNeighborsResult {}

impl PartialOrd for NearestNeighborsResult {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.similarity.partial_cmp(&other.similarity)
Some(self.cmp(other))
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/filesystem/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl filesystem::WritableFileStream for WritableFileStream {
.collect::<Vec<u8>>();

self.cursor_pos += data_len;

Ok(())
}

Expand Down
18 changes: 6 additions & 12 deletions src/filesystem/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,10 @@ impl filesystem::DirectoryHandle for DirectoryHandle {
name: &str,
options: &filesystem::GetFileHandleOptions,
) -> Result<Self::FileHandleT, Self::Error> {
let fs_options = FileSystemGetFileOptions::new();
fs_options.set_create(options.create);
let file_system_file_handle = FileSystemFileHandle::from(
JsFuture::from(self.0.get_file_handle_with_options(
name,
FileSystemGetFileOptions::new().create(options.create),
))
.await?,
JsFuture::from(self.0.get_file_handle_with_options(name, &fs_options)).await?,
);
Ok(FileHandle(file_system_file_handle))
}
Expand All @@ -80,14 +78,10 @@ impl filesystem::FileHandle for FileHandle {
&mut self,
options: &filesystem::CreateWritableOptions,
) -> Result<Self::WritableFileStreamT, Self::Error> {
let fs_options = FileSystemCreateWritableOptions::new();
fs_options.set_keep_existing_data(options.keep_existing_data);
let file_system_writable_file_stream = FileSystemWritableFileStream::unchecked_from_js(
JsFuture::from(
self.0.create_writable_with_options(
FileSystemCreateWritableOptions::new()
.keep_existing_data(options.keep_existing_data),
),
)
.await?,
JsFuture::from(self.0.create_writable_with_options(&fs_options)).await?,
);
Ok(WritableFileStream(file_system_writable_file_stream))
}
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl Db {
})
.unwrap_or(vec![]);

self.victor.write(content, embedding, tags).await;
self.victor.add_embedding(content, embedding, tags).await;
}

pub async fn search(
Expand All @@ -114,7 +114,7 @@ impl Db {

let nearest_neighbors = self
.victor
.find_nearest_neighbors(embedding, tags, top_n.unwrap_or(10.0) as u32)
.search_embedding(embedding, tags, top_n.unwrap_or(10.0) as u32)
.await;

serde_wasm_bindgen::to_value(&nearest_neighbors).unwrap()
Expand Down
Loading

0 comments on commit 70e3a63

Please sign in to comment.