Skip to content

Commit

Permalink
Link ort dynamically in Windows. (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan authored Aug 31, 2024
1 parent 3c2f9ca commit 335edf6
Show file tree
Hide file tree
Showing 8 changed files with 796 additions and 186 deletions.
935 changes: 770 additions & 165 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ itertools = "0.13"
log = "0.4"
memmap2 = "0.9"
safetensors = "0.4"
salvo = "0.71.1"
serde = { version = "1", features = ["derive"] }
tokio = { version = "1", features = ["full"] }

Expand Down
9 changes: 5 additions & 4 deletions assets/configs/Config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ secret_key = "ai00_is_good"
[web] # Remove this to disable WebUI.
path = "assets/www/index.zip" # Path to the WebUI.

# [embed] # Enable embed model, which is based on fast-embedding onnx models.
# endpoint = "https://hf-mirror.com"
# home = "./assets/models/hf"
# name = { MultilingualE5Small = {} }
[embed] # Enable the embedding model, which is based on fastembed ONNX models.
endpoint = "https://hf-mirror.com"
home = "assets/models/hf"
lib = "assets/ort/onnxruntime.dll" # Only used on Windows.
name = { MultilingualE5Small = {} }
Binary file added assets/ort/onnxruntime.dll
Binary file not shown.
6 changes: 2 additions & 4 deletions crates/ai00-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ cbor4ii = { version = "0.3.2", features = ["serde1"] }
fastrand = "2"
half = "2.4"
kbnf = "0.4.1"
voracious_radix_sort = "1.2.0"
qp-trie = "0.8"
rustc-hash = "2.0.0"
uuid = { version = "1.8.0", features = ["serde", "v4"] }
voracious_radix_sort = "1.2.0"

[dependencies.anyhow]
workspace = true
Expand Down Expand Up @@ -53,7 +53,5 @@ workspace = true
workspace = true

[dependencies.salvo]
# git = "https://github.com/salvo-rs/salvo"
default-features = false
features = ["oapi"]
version = "0.71.1"
workspace = true
17 changes: 10 additions & 7 deletions crates/ai00-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@ sha2 = "0.10.8"
simple_logger = { version = "5.0.0", features = ["stderr"] }
tempfile = "3.6"
toml = "0.8.6"
zip-extract = "0.1"
zip-extract = "0.2"

[dependencies.fastembed]
optional = true
version = "3.14"
version = "4"

[target.'cfg(windows)'.dependencies.fastembed]
features = ["ort-load-dynamic"]
optional = true
version = "4"

[dependencies.hf-hub]
optional = true
Expand All @@ -39,11 +44,11 @@ version = "0.3"
[dependencies.text-splitter]
features = ["markdown", "tokenizers"]
optional = true
version = "0.14"
version = "0.15"

[dependencies.tokenizers]
optional = true
version = "0.19"
version = "0.20"

[dependencies.ai00-core]
workspace = true
Expand Down Expand Up @@ -79,8 +84,6 @@ workspace = true
workspace = true

[dependencies.salvo]
# git = "https://github.com/salvo-rs/salvo"
default-features = true
features = [
"acme",
"affix-state",
Expand All @@ -92,4 +95,4 @@ features = [
"serve-static",
"sse",
]
version = "0.71.1"
workspace = true
3 changes: 3 additions & 0 deletions crates/ai00-server/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ pub struct EmbedOption {
pub endpoint: String,
#[derivative(Default(value = "\"assets/models/hf\".into()"))]
pub home: PathBuf,
#[cfg(target_os = "windows")]
#[derivative(Default(value = "\"assets/ort/onnxruntime.dll\".into()"))]
pub lib: PathBuf,
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
Expand Down
11 changes: 5 additions & 6 deletions crates/ai00-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,19 @@ fn load_embed(embed: config::EmbedOption) -> Result<TextEmbed> {

std::env::set_var("HF_ENDPOINT", embed.endpoint);
std::env::set_var("HF_HOME", embed.home);
#[cfg(target_os = "windows")]
std::env::set_var("ORT_DYLIB_PATH", embed.lib);

let api = Api::new()?;
let info = TextEmbedding::get_model_info(&embed.model);
let info = TextEmbedding::get_model_info(&embed.model)?.clone();

let file = api.model(info.model_code.clone()).get("tokenizer.json")?;
let tokenizer = tokenizers::Tokenizer::from_file(file).expect("failed to load tokenizer");

log::info!("loading embed model: {}", embed.model);

let model = TextEmbedding::try_new(InitOptions {
model_name: embed.model,
show_download_progress: true,
..Default::default()
})?;
let model =
TextEmbedding::try_new(InitOptions::new(embed.model).with_show_download_progress(true))?;

Ok(TextEmbed {
tokenizer,
Expand Down

0 comments on commit 335edf6

Please sign in to comment.