diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json deleted file mode 100644 index 10a1c1a..0000000 --- a/.devcontainer/devcontainer.json +++ /dev/null @@ -1,25 +0,0 @@ -// For format details, see https://aka.ms/devcontainer.json. For config options, see the -// README at: https://github.com/devcontainers/templates/tree/main/src/ubuntu -{ - "name": "Ubuntu", - // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile - "image": "mcr.microsoft.com/devcontainers/base:focal", - "features": { - "ghcr.io/devcontainers/features/rust:1": {} - } - - // Features to add to the dev container. More info: https://containers.dev/features. - // "features": {}, - - // Use 'forwardPorts' to make a list of ports inside the container available locally. - // "forwardPorts": [], - - // Use 'postCreateCommand' to run commands after the container is created. - // "postCreateCommand": "uname -a", - - // Configure tool-specific properties. - // "customizations": {}, - - // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. - // "remoteUser": "root" -} diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 0aee45b..db57260 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -32,8 +32,8 @@ jobs: - name: Install WasmEdge run: | - VERSION=0.13.4 - curl -sSf https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | sudo bash -s -- -e all --version=$VERSION --plugins=wasi_nn-tensorflowlite --plugins=wasi_crypto -p /usr/local + VERSION=0.13.5 + curl -sSf https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | sudo bash -s -- -e all --version=$VERSION --plugins wasi_nn-tensorflowlite wasi_crypto -p /usr/local wget https://github.com/WasmEdge/WasmEdge/releases/download/$VERSION/WasmEdge-plugin-wasmedge_rustls-$VERSION-ubuntu20.04_x86_64.tar.gz sudo chmod +x /usr/local/lib/wasmedge diff --git a/Cargo.lock b/Cargo.lock index a04a4a3..7411274 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,6 +65,18 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "bitflags" version = "1.3.2" @@ -73,15 +85,15 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.1" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" [[package]] name = "bytemuck" -version = "1.14.0" +version = "1.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +checksum = "ed2490600f404f2b94c167e31d3ed1d5f3c225a0f3b80230053b3e0b7b962bd9" [[package]] name = "byteorder" @@ -110,12 +122,29 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chat-prompts" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b01b39dd6fff99d78eaef963e1b2076754a81850417356fa623bdee253eebbd" +dependencies = [ + "endpoints", + "enum_dispatch", + "thiserror", +] + [[package]] name = "color_quant" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "conv" version = "0.3.3" @@ -136,33 +165,38 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fca89a0e215bab21874660c67903c5f143333cab1da83d041c7ded6053774751" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" dependencies = [ - "cfg-if", "crossbeam-epoch", "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" -version = "0.9.17" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e3681d554572a651dda4186cd47240627c3d0114d45a95f6ad27f2f22e7548d" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" dependencies = [ - "autocfg", - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.18" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + +[[package]] +name = "crypto-wasi" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac0fad491b70b319705c5564b53ee302299f2654972119330a9e1f62f4baa303" dependencies = [ - "cfg-if", + "base64 0.21.7", + "der", + "pem", ] [[package]] @@ -181,6 +215,16 @@ dependencies = [ "byteorder", ] +[[package]] +name = "der" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" +dependencies = [ + "const-oid", + "zeroize", +] + [[package]] name = "dns-parser" version = "0.8.0" @@ -261,11 +305,32 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a246d82be1c9d791c5dfde9a2bd045fc3cbba3fa2b11ad558f27d01712f00569" +[[package]] +name = "endpoints" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "844f278cb4e558c0cb2f920bbb221884bfa35cace9e61d3e238fc01ec09ca8a5" +dependencies = [ + "serde", +] + +[[package]] +name = "enum_dispatch" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f33313078bb8d4d05a2733a94ac4c2d8a0df9a2b84424ebf4f33bfc224a890e" +dependencies = [ + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "env_logger" -version = "0.10.1" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95b3f3e67048839cb0d0781f445682a35113da7121f7c949db0e2be96a4fbece" +checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" dependencies = [ "humantime", "is-terminal", @@ -301,14 +366,14 @@ checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.9.0+wasi-snapshot-preview1", ] [[package]] name = "getrandom" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "libc", @@ -323,9 +388,9 @@ checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "hermit-abi" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f" [[package]] name = "humantime" @@ -378,13 +443,13 @@ dependencies = [ [[package]] name = "is-terminal" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" +checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" dependencies = [ "hermit-abi", "rustix", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -410,15 +475,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.151" +version = "0.2.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" +checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" [[package]] name = "linux-raw-sys" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "lock_api" @@ -447,9 +512,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.6.4" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "miniz_oxide" @@ -588,6 +653,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + [[package]] name = "owned_ttf_parser" version = "0.15.2" @@ -620,6 +691,15 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "pem" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8835c273a76a90455d7344889b0964598e3316e2a79ede8e36f16bdcf2228b8" +dependencies = [ + "base64 0.13.1", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -652,9 +732,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.71" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] @@ -667,9 +747,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -733,7 +813,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.11", + "getrandom 0.2.12", ] [[package]] @@ -762,9 +842,9 @@ checksum = "ebac11a9d2e11f2af219b8b8d833b76b1ea0e054aa0e8d8e9e4cbde353bdf019" [[package]] name = "rayon" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" dependencies = [ "either", "rayon-core", @@ -772,9 +852,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -791,9 +871,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.2" +version = "1.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" dependencies = [ "aho-corasick", "memchr", @@ -803,9 +883,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.3" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" dependencies = [ "aho-corasick", "memchr", @@ -825,7 +905,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" dependencies = [ "cc", - "getrandom 0.2.11", + "getrandom 0.2.12", "libc", "spin", "untrusted", @@ -850,11 +930,11 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.28" +version = "0.38.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" +checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "errno", "libc", "linux-raw-sys", @@ -909,6 +989,26 @@ dependencies = [ "untrusted", ] +[[package]] +name = "serde" +version = "1.0.196" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.196" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -920,9 +1020,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.2" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "socket2" @@ -951,15 +1051,46 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "termcolor" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff1bc3d3f05aff0403e8ac0d92ced918ec05b666a43f83297ccef5bea8a3d449" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" dependencies = [ "winapi-util", ] +[[package]] +name = "thiserror" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -993,7 +1124,7 @@ checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -1036,9 +1167,9 @@ checksum = "7b3e06c9b9d80ed6b745c7159c40b311ad2916abb34a49e9be2653b90db0d8dd" [[package]] name = "unicode-bidi" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" [[package]] name = "unicode-ident" @@ -1078,12 +1209,29 @@ version = "0.9.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasi-nn" +version = "0.6.0" +source = "git+https://github.com/second-state/wasmedge-wasi-nn?branch=ggml#891f7c414bf1eecaa1b36c5792d1c88097ceafd6" +dependencies = [ + "thiserror", +] + [[package]] name = "wasmedge_quickjs" version = "0.5.0-alpha" dependencies = [ "argparse", + "chat-prompts", + "crypto-wasi", "encoding", + "endpoints", "env_logger", "image", "imageproc", @@ -1094,6 +1242,7 @@ dependencies = [ "tokio-rustls-wasi", "tokio_wasi", "url", + "wasi-nn", "wasmedge_wasi_socket", "webpki-roots", ] @@ -1344,3 +1493,9 @@ name = "windows_x86_64_msvc" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/Cargo.toml b/Cargo.toml index cdf51ab..f5667a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,11 +32,16 @@ tokio-rustls-wasi = { version = "0.24.1", optional = true } webpki-roots = { version = "0.25.0", optional = true } crypto-wasi = { version = "0.1.1", optional = true } +chat-prompts = { version = "0.3", optional = true } +wasi-nn = { git = "https://github.com/second-state/wasmedge-wasi-nn", branch = "ggml", optional = true } +endpoints = { version = "0.2", optional = true } + [features] default = ["tls"] tls = ["rustls", "tokio-rustls-wasi", "webpki-roots"] img = ["image", "imageproc"] tensorflow = ["img"] wasi_nn = ["img"] +ggml = ["chat-prompts", "dep:wasi-nn", "endpoints"] cjs = [] nodejs_crypto = ["crypto-wasi"] diff --git a/example_js/ggml_chat.js b/example_js/ggml_chat.js new file mode 100644 index 0000000..f2d707f --- /dev/null +++ b/example_js/ggml_chat.js @@ -0,0 +1,57 @@ +import { GGMLChatCompletionRequest, GGMLChatPrompt } from '_wasi_nn_ggml_template' +import { build_graph_from_cache } from '_wasi_nn_ggml' +function main() { + let opt = { + "enable-log": true, + "ctx_size": 512, + "n-predict": 1024, + "n-gpu-layers": 100, + "batch-size": 512, + "temp": 0.8, + "repeat-penalty": 1.1 + } + + let graph = build_graph_from_cache(3, JSON.stringify(opt), "default") + let context = graph.init_execution_context() + + let template = new GGMLChatPrompt('llama-2-chat') + + let req = new GGMLChatCompletionRequest() + + let messages = ['hello', 'who are you?'] + + for (var i in messages) { + print("[YOU]:", messages[i]) + req.push_message("user", messages[i]) + let p = template.build(req) + context.set_input(0, p, [1], 3) + var ss = ''; + + while (1) { + try { + context.compute_single() + let s = context.get_output_single(0, 1) + ss += s; + print('BOT:', s) + } catch (e) { + if (e['type'] == "BackendError" && e['message'] == "EndOfSequence") { + print('[log] EndOfSequence!') + break + } else if (e['type'] == "BackendError" && e['message'] == "ContextFull") { + print('[log] ContextFull!') + break + } else { + return + } + } + } + req.push_message("assistant", ss) + print("[BOT]:", ss); + } + + let p = template.build(req); + print() + print(p) +} + +main() \ No newline at end of file diff --git a/lib/binding.rs b/lib/binding.rs index 655d20d..83cdd45 100644 --- a/lib/binding.rs +++ b/lib/binding.rs @@ -1,4 +1,4 @@ -/* automatically generated by rust-bindgen 0.66.1 */ +/* automatically generated by rust-bindgen 0.68.1 */ pub const JS_PROP_CONFIGURABLE: u32 = 1; pub const JS_PROP_WRITABLE: u32 = 2; @@ -31,6 +31,7 @@ pub const JS_EVAL_FLAG_STRICT: u32 = 8; pub const JS_EVAL_FLAG_STRIP: u32 = 16; pub const JS_EVAL_FLAG_COMPILE_ONLY: u32 = 32; pub const JS_EVAL_FLAG_BACKTRACE_BARRIER: u32 = 64; +pub const JS_EVAL_FLAG_ASYNC: u32 = 128; pub const JS_ATOM_NULL: u32 = 0; pub const JS_CALL_FLAG_CONSTRUCTOR: u32 = 1; pub const JS_GPN_STRING_MASK: u32 = 1; @@ -57,7 +58,6 @@ pub const JS_DEF_PROP_DOUBLE: u32 = 6; pub const JS_DEF_PROP_UNDEFINED: u32 = 7; pub const JS_DEF_OBJECT: u32 = 8; pub const JS_DEF_ALIAS: u32 = 9; - #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct JSRuntime { @@ -80,24 +80,24 @@ pub struct JSClass { } pub type JSClassID = u32; pub type JSAtom = u32; -pub const JS_TAG_FIRST: _bindgen_ty_1 = -11; -pub const JS_TAG_BIG_DECIMAL: _bindgen_ty_1 = -11; -pub const JS_TAG_BIG_INT: _bindgen_ty_1 = -10; -pub const JS_TAG_BIG_FLOAT: _bindgen_ty_1 = -9; -pub const JS_TAG_SYMBOL: _bindgen_ty_1 = -8; -pub const JS_TAG_STRING: _bindgen_ty_1 = -7; -pub const JS_TAG_MODULE: _bindgen_ty_1 = -3; -pub const JS_TAG_FUNCTION_BYTECODE: _bindgen_ty_1 = -2; -pub const JS_TAG_OBJECT: _bindgen_ty_1 = -1; -pub const JS_TAG_INT: _bindgen_ty_1 = 0; -pub const JS_TAG_BOOL: _bindgen_ty_1 = 1; -pub const JS_TAG_NULL: _bindgen_ty_1 = 2; -pub const JS_TAG_UNDEFINED: _bindgen_ty_1 = 3; -pub const JS_TAG_UNINITIALIZED: _bindgen_ty_1 = 4; -pub const JS_TAG_CATCH_OFFSET: _bindgen_ty_1 = 5; -pub const JS_TAG_EXCEPTION: _bindgen_ty_1 = 6; -pub const JS_TAG_FLOAT64: _bindgen_ty_1 = 7; -pub type _bindgen_ty_1 = ::std::os::raw::c_int; +pub const JS_TAG_JS_TAG_FIRST: JS_TAG = -11; +pub const JS_TAG_JS_TAG_BIG_DECIMAL: JS_TAG = -11; +pub const JS_TAG_JS_TAG_BIG_INT: JS_TAG = -10; +pub const JS_TAG_JS_TAG_BIG_FLOAT: JS_TAG = -9; +pub const JS_TAG_JS_TAG_SYMBOL: JS_TAG = -8; +pub const JS_TAG_JS_TAG_STRING: JS_TAG = -7; +pub const JS_TAG_JS_TAG_MODULE: JS_TAG = -3; +pub const JS_TAG_JS_TAG_FUNCTION_BYTECODE: JS_TAG = -2; +pub const JS_TAG_JS_TAG_OBJECT: JS_TAG = -1; +pub const JS_TAG_JS_TAG_INT: JS_TAG = 0; +pub const JS_TAG_JS_TAG_BOOL: JS_TAG = 1; +pub const JS_TAG_JS_TAG_NULL: JS_TAG = 2; +pub const JS_TAG_JS_TAG_UNDEFINED: JS_TAG = 3; +pub const JS_TAG_JS_TAG_UNINITIALIZED: JS_TAG = 4; +pub const JS_TAG_JS_TAG_CATCH_OFFSET: JS_TAG = 5; +pub const JS_TAG_JS_TAG_EXCEPTION: JS_TAG = 6; +pub const JS_TAG_JS_TAG_FLOAT64: JS_TAG = 7; +pub type JS_TAG = ::std::os::raw::c_int; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct JSRefCountHeader { @@ -382,7 +382,16 @@ pub struct JSMemoryUsage { pub binary_object_count: i64, pub binary_object_size: i64, } - +extern "C" { + pub fn JS_ComputeMemoryUsage(rt: *mut JSRuntime, s: *mut JSMemoryUsage); +} +extern "C" { + pub fn JS_DumpMemoryUsage( + fp: *mut ::std::os::raw::c_int, + s: *const JSMemoryUsage, + rt: *mut JSRuntime, + ); +} extern "C" { pub fn JS_NewAtomLen( ctx: *mut JSContext, @@ -710,9 +719,10 @@ extern "C" { extern "C" { pub fn JS_SetPropertyInternal( ctx: *mut JSContext, - this_obj: JSValue, + obj: JSValue, prop: JSAtom, val: JSValue, + this_obj: JSValue, flags: ::std::os::raw::c_int, ) -> ::std::os::raw::c_int; } @@ -1002,9 +1012,19 @@ extern "C" { sf: *const JSSharedArrayBufferFunctions, ); } +pub const JSPromiseStateEnum_JS_PROMISE_PENDING: JSPromiseStateEnum = 0; +pub const JSPromiseStateEnum_JS_PROMISE_FULFILLED: JSPromiseStateEnum = 1; +pub const JSPromiseStateEnum_JS_PROMISE_REJECTED: JSPromiseStateEnum = 2; +pub type JSPromiseStateEnum = ::std::os::raw::c_uint; extern "C" { pub fn JS_NewPromiseCapability(ctx: *mut JSContext, resolving_funcs: *mut JSValue) -> JSValue; } +extern "C" { + pub fn JS_PromiseState(ctx: *mut JSContext, promise: JSValue) -> JSPromiseStateEnum; +} +extern "C" { + pub fn JS_PromiseResult(ctx: *mut JSContext, promise: JSValue) -> JSValue; +} pub type JSHostPromiseRejectionTracker = ::std::option::Option< unsafe extern "C" fn( ctx: *mut JSContext, @@ -1074,6 +1094,9 @@ extern "C" { extern "C" { pub fn JS_GetModuleName(ctx: *mut JSContext, m: *mut JSModuleDef) -> JSAtom; } +extern "C" { + pub fn JS_GetModuleNamespace(ctx: *mut JSContext, m: *mut JSModuleDef) -> JSValue; +} pub type JSJobFunc = ::std::option::Option< unsafe extern "C" fn( ctx: *mut JSContext, @@ -1137,25 +1160,25 @@ extern "C" { ) -> JSAtom; } extern "C" { - pub fn JS_RunModule( + pub fn JS_LoadModule( ctx: *mut JSContext, basename: *const ::std::os::raw::c_char, filename: *const ::std::os::raw::c_char, - ) -> *mut JSModuleDef; + ) -> JSValue; } -pub const JS_CFUNC_generic: JSCFunctionEnum = 0; -pub const JS_CFUNC_generic_magic: JSCFunctionEnum = 1; -pub const JS_CFUNC_constructor: JSCFunctionEnum = 2; -pub const JS_CFUNC_constructor_magic: JSCFunctionEnum = 3; -pub const JS_CFUNC_constructor_or_func: JSCFunctionEnum = 4; -pub const JS_CFUNC_constructor_or_func_magic: JSCFunctionEnum = 5; -pub const JS_CFUNC_f_f: JSCFunctionEnum = 6; -pub const JS_CFUNC_f_f_f: JSCFunctionEnum = 7; -pub const JS_CFUNC_getter: JSCFunctionEnum = 8; -pub const JS_CFUNC_setter: JSCFunctionEnum = 9; -pub const JS_CFUNC_getter_magic: JSCFunctionEnum = 10; -pub const JS_CFUNC_setter_magic: JSCFunctionEnum = 11; -pub const JS_CFUNC_iterator_next: JSCFunctionEnum = 12; +pub const JSCFunctionEnum_JS_CFUNC_generic: JSCFunctionEnum = 0; +pub const JSCFunctionEnum_JS_CFUNC_generic_magic: JSCFunctionEnum = 1; +pub const JSCFunctionEnum_JS_CFUNC_constructor: JSCFunctionEnum = 2; +pub const JSCFunctionEnum_JS_CFUNC_constructor_magic: JSCFunctionEnum = 3; +pub const JSCFunctionEnum_JS_CFUNC_constructor_or_func: JSCFunctionEnum = 4; +pub const JSCFunctionEnum_JS_CFUNC_constructor_or_func_magic: JSCFunctionEnum = 5; +pub const JSCFunctionEnum_JS_CFUNC_f_f: JSCFunctionEnum = 6; +pub const JSCFunctionEnum_JS_CFUNC_f_f_f: JSCFunctionEnum = 7; +pub const JSCFunctionEnum_JS_CFUNC_getter: JSCFunctionEnum = 8; +pub const JSCFunctionEnum_JS_CFUNC_setter: JSCFunctionEnum = 9; +pub const JSCFunctionEnum_JS_CFUNC_getter_magic: JSCFunctionEnum = 10; +pub const JSCFunctionEnum_JS_CFUNC_setter_magic: JSCFunctionEnum = 11; +pub const JSCFunctionEnum_JS_CFUNC_iterator_next: JSCFunctionEnum = 12; pub type JSCFunctionEnum = ::std::os::raw::c_uint; #[repr(C)] #[derive(Copy, Clone)] @@ -1215,11 +1238,6 @@ pub union JSCFunctionType { ) -> JSValue, >, } -impl ::std::fmt::Debug for JSCFunctionType { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "JSCFunctionType {{ union }}") - } -} extern "C" { pub fn JS_NewCFunction2( ctx: *mut JSContext, @@ -1271,30 +1289,12 @@ pub struct JSCFunctionListEntry__bindgen_ty_1__bindgen_ty_1 { pub cproto: u8, pub cfunc: JSCFunctionType, } -impl ::std::fmt::Debug for JSCFunctionListEntry__bindgen_ty_1__bindgen_ty_1 { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!( - f, - "JSCFunctionListEntry__bindgen_ty_1__bindgen_ty_1 {{ cfunc: {:?} }}", - self.cfunc - ) - } -} #[repr(C)] #[derive(Copy, Clone)] pub struct JSCFunctionListEntry__bindgen_ty_1__bindgen_ty_2 { pub get: JSCFunctionType, pub set: JSCFunctionType, } -impl ::std::fmt::Debug for JSCFunctionListEntry__bindgen_ty_1__bindgen_ty_2 { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!( - f, - "JSCFunctionListEntry__bindgen_ty_1__bindgen_ty_2 {{ get: {:?}, set: {:?} }}", - self.get, self.set - ) - } -} #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct JSCFunctionListEntry__bindgen_ty_1__bindgen_ty_3 { @@ -1307,20 +1307,6 @@ pub struct JSCFunctionListEntry__bindgen_ty_1__bindgen_ty_4 { pub tab: *const JSCFunctionListEntry, pub len: ::std::os::raw::c_int, } -impl ::std::fmt::Debug for JSCFunctionListEntry__bindgen_ty_1 { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "JSCFunctionListEntry__bindgen_ty_1 {{ union }}") - } -} -impl ::std::fmt::Debug for JSCFunctionListEntry { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!( - f, - "JSCFunctionListEntry {{ name: {:?}, u: {:?} }}", - self.name, self.u - ) - } -} extern "C" { pub fn JS_SetPropertyFunctionList( ctx: *mut JSContext, @@ -1517,9 +1503,6 @@ extern "C" { extern "C" { pub fn JS_GetPromiseResult_real(ctx: *mut JSContext, this_val: JSValue) -> JSValue; } -extern "C" { - pub fn JS_GetPromiseState(ctx: *mut JSContext, this_val: JSValue) -> ::std::os::raw::c_int; -} extern "C" { pub fn JS_ToUint32_real( ctx: *mut JSContext, @@ -1562,9 +1545,6 @@ extern "C" { eval_flags: ::std::os::raw::c_int, ) -> ::std::os::raw::c_int; } -extern "C" { - pub fn js_require(ctx: *mut JSContext, specifier: JSValue) -> JSValue; -} extern "C" { pub fn js_undefined() -> JSValue; } diff --git a/lib/libquickjs.a b/lib/libquickjs.a index 208194f..f6c59e5 100644 Binary files a/lib/libquickjs.a and b/lib/libquickjs.a differ diff --git a/modules/process.js b/modules/process.js index 92ae6b2..73f304d 100644 --- a/modules/process.js +++ b/modules/process.js @@ -6,7 +6,7 @@ var title = 'wasmedge_quickjs'; var arch = 'wasm'; var platform = 'wasi'; var env = globalThis.env; -var argv = globalThis.argv; +var argv = globalThis.argv || globalThis.args; var execArgv = []; var version = 'v16.8.0'; var versions = {}; diff --git a/src/internal_module/ggml/mod.rs b/src/internal_module/ggml/mod.rs new file mode 100644 index 0000000..05b8e19 --- /dev/null +++ b/src/internal_module/ggml/mod.rs @@ -0,0 +1,576 @@ +use std::str::FromStr; + +use chat_prompts::{ + chat::{BuildChatPrompt, ChatPrompt}, + PromptTemplateType, +}; +use endpoints::chat::{ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole}; +use wasi_nn::BackendError; + +use crate::{ + register_class, AsObject, Context, JsClassDef, JsClassTool, JsModuleDef, JsObject, JsValue, + SelfRefJsValue, +}; + +struct WasiNNGraph(wasi_nn::Graph); + +impl JsClassDef for WasiNNGraph { + type RefType = WasiNNGraph; + + const CLASS_NAME: &'static str = "Graph"; + + const CONSTRUCTOR_ARGC: u8 = 0; + + const FIELDS: &'static [crate::JsClassField] = &[]; + + const METHODS: &'static [crate::JsClassMethod] = + &[("init_execution_context", 0, Self::js_init_execution_context)]; + + unsafe fn mut_class_id_ptr() -> &'static mut u32 { + static mut CLASS_ID: u32 = 0; + &mut CLASS_ID + } + + fn constructor_fn( + _ctx: &mut crate::Context, + _argv: &[JsValue], + ) -> Result { + Err(JsValue::UnDefined) + } +} + +impl WasiNNGraph { + pub fn js_init_execution_context( + &mut self, + this: &mut JsObject, + js_ctx: &mut Context, + _argv: &[JsValue], + ) -> JsValue { + let r = Self::self_ref_opaque_mut(this.clone().into(), |v| v.0.init_execution_context()); + match r { + None => JsValue::UnDefined, + Some(Ok(ctx)) => { + WasiNNGraphExecutionContext::wrap_obj(js_ctx, WasiNNGraphExecutionContext { ctx }) + } + Some(Err(e)) => { + let err = ggml_error_to_js_error(js_ctx, e); + js_ctx.throw_error(err).into() + } + } + } +} + +struct WasiNNGraphExecutionContext { + ctx: SelfRefJsValue>, +} + +impl JsClassDef for WasiNNGraphExecutionContext { + type RefType = Self; + + const CLASS_NAME: &'static str = "GraphExecutionContext"; + + const CONSTRUCTOR_ARGC: u8 = 0; + + const FIELDS: &'static [crate::JsClassField] = &[]; + + const METHODS: &'static [crate::JsClassMethod] = &[ + ("set_input", 4, Self::js_set_input), + ("compute", 0, Self::js_compute), + ("compute_single", 0, Self::js_compute_single), + ("fini_single", 0, Self::js_fini_single), + ("get_output", 2, Self::js_get_output), + ("get_output_single", 2, Self::js_get_output_single), + ]; + + unsafe fn mut_class_id_ptr() -> &'static mut u32 { + static mut CLASS_ID: u32 = 0; + &mut CLASS_ID + } + + fn constructor_fn(_ctx: &mut Context, _argv: &[JsValue]) -> Result { + Err(JsValue::UnDefined) + } +} + +lazy_static::lazy_static! { + static ref MAX_OUTPUT_SIZE: usize ={ + std::env::var("GGML_OUTPUT_BUFF_SIZE") + .unwrap_or_default() + .parse() + .unwrap_or(1024) + }; +} + +fn ggml_error_to_js_error(ctx: &mut Context, error: wasi_nn::Error) -> JsValue { + let (t, msg) = match error { + wasi_nn::Error::IoError(e) => { + let mut js_err = ctx.new_error(e.to_string().as_str()); + if let JsValue::Object(js_err) = &mut js_err { + js_err.set("type", ctx.new_string("IO").into()); + }; + return js_err; + } + wasi_nn::Error::BackendError(BackendError::InvalidArgument) => { + ("BackendError", "InvalidArgument") + } + wasi_nn::Error::BackendError(BackendError::InvalidEncoding) => { + ("BackendError", "InvalidEncoding") + } + wasi_nn::Error::BackendError(BackendError::MissingMemory) => { + ("BackendError", "MissingMemory") + } + wasi_nn::Error::BackendError(BackendError::Busy) => ("BackendError", "Busy"), + wasi_nn::Error::BackendError(BackendError::RuntimeError) => { + ("BackendError", "RuntimeError") + } + wasi_nn::Error::BackendError(BackendError::UnsupportedOperation) => { + ("BackendError", "UnsupportedOperation") + } + wasi_nn::Error::BackendError(BackendError::TooLarge) => ("BackendError", "TooLarge"), + wasi_nn::Error::BackendError(BackendError::NotFound) => ("BackendError", "NotFound"), + wasi_nn::Error::BackendError(BackendError::EndOfSequence) => { + ("BackendError", "EndOfSequence") + } + wasi_nn::Error::BackendError(BackendError::ContextFull) => ("BackendError", "ContextFull"), + wasi_nn::Error::BackendError(BackendError::PromptTooLong) => { + ("BackendError", "PromptTooLong") + } + wasi_nn::Error::BackendError(BackendError::UnknownError(i)) => { + let mut js_err = ctx.new_error(format!("UnknownError:{i}").as_str()); + if let JsValue::Object(js_err) = &mut js_err { + js_err.set("type", ctx.new_string("BackendError").into()); + }; + return js_err; + } + }; + let mut js_err = ctx.new_error(msg); + if let JsValue::Object(js_err) = &mut js_err { + js_err.set("type", ctx.new_string(t).into()); + }; + js_err +} + +impl WasiNNGraphExecutionContext { + fn js_set_input( + &mut self, + _this_obj: &mut JsObject, + ctx: &mut Context, + argv: &[JsValue], + ) -> JsValue { + let index = if let Some(JsValue::Int(index)) = argv.get(0) { + *index as usize + } else { + return ctx.throw_type_error("'index' must be of type int").into(); + }; + + let tensor_buf = match argv.get(1) { + Some(JsValue::ArrayBuffer(buf)) => buf.as_ref(), + Some(JsValue::String(s)) => s.as_str().trim().as_bytes(), + _ => { + return ctx + .throw_type_error("'tensor_buf' must be of type buffer or string") + .into(); + } + }; + + let dimensions = if let Some(JsValue::Array(arr)) = argv.get(2) { + match arr.to_vec() { + Ok(dimensions) => { + let mut dimension_arr = Vec::with_capacity(dimensions.len()); + + for i in dimensions { + let v = match i { + JsValue::Int(i) => i as usize, + JsValue::Float(i) => i as usize, + _ => { + return ctx + .throw_type_error("'dimensions' must be of type number array") + .into() + } + }; + dimension_arr.push(v); + } + dimension_arr + } + Err(e) => return e.into(), + } + } else { + return ctx + .throw_type_error("'dimensions' must be of type array") + .into(); + }; + + let tensor_type = if let Some(JsValue::Int(input_type)) = argv.get(3) { + let input_type = *input_type; + match input_type { + 0 => wasi_nn::TensorType::F16, + 1 => wasi_nn::TensorType::F32, + 2 => wasi_nn::TensorType::F64, + 3 => wasi_nn::TensorType::U8, + 4 => wasi_nn::TensorType::I32, + 5 => wasi_nn::TensorType::I64, + + _ => { + return ctx + .throw_type_error(&format!("undefined `input_type` {}", input_type)) + .into(); + } + } + } else { + return ctx.throw_type_error("'index' must be of type int").into(); + }; + + if let Err(e) = self + .ctx + .set_input(index, tensor_type, &dimensions, tensor_buf) + { + let err = ggml_error_to_js_error(ctx, e); + ctx.throw_error(err).into() + } else { + JsValue::UnDefined + } + } + + fn js_compute( + &mut self, + _this_obj: &mut JsObject, + ctx: &mut Context, + _argv: &[JsValue], + ) -> JsValue { + if let Err(e) = self.ctx.compute() { + let err = ggml_error_to_js_error(ctx, e); + ctx.throw_error(err).into() + } else { + JsValue::UnDefined + } + } + + fn js_compute_single( + &mut self, + _this_obj: &mut JsObject, + ctx: &mut Context, + _argv: &[JsValue], + ) -> JsValue { + if let Err(e) = self.ctx.compute_single() { + let err = ggml_error_to_js_error(ctx, e); + ctx.throw_error(err).into() + } else { + JsValue::UnDefined + } + } + + fn js_fini_single( + &mut self, + _this_obj: &mut JsObject, + ctx: &mut Context, + _argv: &[JsValue], + ) -> JsValue { + if let Err(e) = self.ctx.fini_single() { + let err = ggml_error_to_js_error(ctx, e); + ctx.throw_error(err).into() + } else { + JsValue::UnDefined + } + } + + fn js_get_output_single( + &mut self, + _this_obj: &mut JsObject, + ctx: &mut Context, + argv: &[JsValue], + ) -> JsValue { + let index = if let Some(JsValue::Int(index)) = argv.get(0) { + *index as usize + } else { + return ctx.throw_type_error("'index' must be of type int").into(); + }; + + let output_type = if let Some(JsValue::Int(type_index)) = argv.get(1) { + *type_index + } else { + return ctx + .throw_type_error("'output_type' must be of type Int") + .into(); + }; + + let mut output_buffer = vec![0u8; *MAX_OUTPUT_SIZE]; + + match self.ctx.get_output_single(index, output_buffer.as_mut()) { + Ok(n) => match output_type { + 0 => ctx.new_array_buffer(&output_buffer[0..n]).into(), + _ => ctx + .new_string(unsafe { std::str::from_utf8_unchecked(&output_buffer[0..n]) }) + .into(), + }, + Err(e) => { + let err = ggml_error_to_js_error(ctx, e); + ctx.throw_error(err).into() + } + } + } + + fn js_get_output( + &mut self, + _this_obj: &mut JsObject, + ctx: &mut Context, + argv: &[JsValue], + ) -> JsValue { + let index = if let Some(JsValue::Int(index)) = argv.get(0) { + *index as usize + } else { + return ctx.throw_type_error("'index' must be of type int").into(); + }; + + let mut output = if let Some(JsValue::ArrayBuffer(buf)) = argv.get(1) { + buf.clone() + } else { + return ctx + .throw_type_error("'output' must be of type buffer") + .into(); + }; + + match self.ctx.get_output(index, output.as_mut()) { + Ok(n) => JsValue::Int(n as i32), + Err(e) => { + let err = ggml_error_to_js_error(ctx, e); + ctx.throw_error(err).into() + } + } + } +} + +fn js_build_graph_from_cache(ctx: &mut Context, _this: JsValue, param: &[JsValue]) -> JsValue { + if let Some( + [JsValue::Int(target_index), JsValue::String(metadata), JsValue::String(module_name)], + ) = param.get(0..3) + { + let target = match *target_index { + 0 => wasi_nn::ExecutionTarget::CPU, + 1 => wasi_nn::ExecutionTarget::GPU, + 2 => wasi_nn::ExecutionTarget::TPU, + _ => wasi_nn::ExecutionTarget::AUTO, + }; + let config = wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Ggml, target) + .config(metadata.to_string()) + .build_from_cache(module_name.as_str()); + + match config { + Ok(g) => WasiNNGraph::wrap_obj(ctx, WasiNNGraph(g)), + Err(e) => { + let err = ggml_error_to_js_error(ctx, e); + ctx.throw_error(err).into() + } + } + } else { + JsValue::UnDefined + } +} + +pub fn init_wasi_nn_ggml_module(ctx: &mut Context) { + ctx.register_fn_module( + "_wasi_nn_ggml", + &[ + WasiNNGraph::CLASS_NAME, + WasiNNGraphExecutionContext::CLASS_NAME, + "build_graph_from_cache", + ], + |ctx, m| { + let class_ctor = register_class::(ctx); + m.add_export(WasiNNGraph::CLASS_NAME, class_ctor); + + let class_ctor = register_class::(ctx); + m.add_export(WasiNNGraphExecutionContext::CLASS_NAME, class_ctor); + + let f = ctx.wrap_function("build_graph_from_cache", js_build_graph_from_cache); + m.add_export("build_graph_from_cache", f.into()); + }, + ) +} + +struct GGMLChatPromptTemplate { + prompt: ChatPrompt, +} + +fn create_prompt_template(template_ty: PromptTemplateType) -> ChatPrompt { + match template_ty { + PromptTemplateType::Llama2Chat => { + ChatPrompt::Llama2ChatPrompt(chat_prompts::chat::llama::Llama2ChatPrompt::default()) + } + PromptTemplateType::MistralInstruct => ChatPrompt::MistralInstructPrompt( + chat_prompts::chat::mistral::MistralInstructPrompt::default(), + ), + PromptTemplateType::MistralLite => { + ChatPrompt::MistralLitePrompt(chat_prompts::chat::mistral::MistralLitePrompt::default()) + } + PromptTemplateType::OpenChat => { + ChatPrompt::OpenChatPrompt(chat_prompts::chat::openchat::OpenChatPrompt::default()) + } + PromptTemplateType::CodeLlama => ChatPrompt::CodeLlamaInstructPrompt( + chat_prompts::chat::llama::CodeLlamaInstructPrompt::default(), + ), + PromptTemplateType::BelleLlama2Chat => ChatPrompt::BelleLlama2ChatPrompt( + chat_prompts::chat::belle::BelleLlama2ChatPrompt::default(), + ), + PromptTemplateType::VicunaChat => { + ChatPrompt::VicunaChatPrompt(chat_prompts::chat::vicuna::VicunaChatPrompt::default()) + } + PromptTemplateType::Vicuna11Chat => { + ChatPrompt::Vicuna11ChatPrompt(chat_prompts::chat::vicuna::Vicuna11ChatPrompt::default()) + } + PromptTemplateType::ChatML => { + ChatPrompt::ChatMLPrompt(chat_prompts::chat::chatml::ChatMLPrompt::default()) + } + PromptTemplateType::Baichuan2 => ChatPrompt::Baichuan2ChatPrompt( + chat_prompts::chat::baichuan::Baichuan2ChatPrompt::default(), + ), + PromptTemplateType::WizardCoder => { + ChatPrompt::WizardCoderPrompt(chat_prompts::chat::wizard::WizardCoderPrompt::default()) + } + PromptTemplateType::Zephyr => { + ChatPrompt::ZephyrChatPrompt(chat_prompts::chat::zephyr::ZephyrChatPrompt::default()) + } + PromptTemplateType::IntelNeural => { + ChatPrompt::NeuralChatPrompt(chat_prompts::chat::intel::NeuralChatPrompt::default()) + } + PromptTemplateType::DeepseekChat => ChatPrompt::DeepseekChatPrompt( + chat_prompts::chat::deepseek::DeepseekChatPrompt::default(), + ), + PromptTemplateType::DeepseekCoder => ChatPrompt::DeepseekCoderPrompt( + chat_prompts::chat::deepseek::DeepseekCoderPrompt::default(), + ), + PromptTemplateType::SolarInstruct => ChatPrompt::SolarInstructPrompt( + chat_prompts::chat::solar::SolarInstructPrompt::default(), + ), + } +} + +impl GGMLChatPromptTemplate { + fn js_build( + &mut self, + _this_obj: &mut JsObject, + ctx: &mut Context, + argv: &[JsValue], + ) -> JsValue { + if let Some(JsValue::Object(js_obj)) = argv.first() { + let mut js_obj = js_obj.clone().into(); + if let Some(req) = GGMLChatCompletionRequest::opaque_mut(&mut js_obj) { + return match self.prompt.build(&mut req.req.messages) { + Ok(s) => ctx.new_string(s.as_str()).into(), + Err(e) => { + let error = ctx.new_error(e.to_string().as_str()); + ctx.throw_error(error).into() + } + }; + } + } + ctx.throw_type_error("'request' must be of type GGMLChatCompletionRequest") + .into() + } +} + +impl JsClassDef for GGMLChatPromptTemplate { + type RefType = GGMLChatPromptTemplate; + + const CLASS_NAME: &'static str = "GGMLChatPrompt"; + + const CONSTRUCTOR_ARGC: u8 = 1; + + const FIELDS: &'static [crate::JsClassField] = &[]; + + const METHODS: &'static [crate::JsClassMethod] = &[("build", 1, Self::js_build)]; + + unsafe fn mut_class_id_ptr() -> &'static mut u32 { + static mut CLASS_ID: u32 = 0; + &mut CLASS_ID + } + + fn constructor_fn(ctx: &mut Context, argv: &[JsValue]) -> Result { + if let Some(JsValue::String(type_str)) = argv.first() { + match PromptTemplateType::from_str(type_str.as_str()) { + Ok(template_ty) => Ok(Self { + prompt: create_prompt_template(template_ty), + }), + Err(_) => Err(JsValue::UnDefined), + } + } else { + Err(ctx + .throw_type_error("'tensor_buf' must be of type buffer or string") + .into()) + } + } +} + +struct GGMLChatCompletionRequest { + req: ChatCompletionRequest, +} + +impl GGMLChatCompletionRequest { + fn js_push_message( + &mut self, + _this_obj: &mut JsObject, + ctx: &mut Context, + argv: &[JsValue], + ) -> JsValue { + if let Some([JsValue::String(role), JsValue::String(content)]) = argv.get(0..2) { + let role = + match role.as_str() { + "system" => ChatCompletionRole::System, + "user" => ChatCompletionRole::User, + "function" => ChatCompletionRole::Function, + "assistant" => ChatCompletionRole::Assistant, + _ => return ctx + .throw_type_error( + "`role` must be either `system`, `user`, `assistant`, or `function`.", + ) + .into(), + }; + self.req + .messages + .push(ChatCompletionRequestMessage::new(role, content.as_str())); + JsValue::UnDefined + } else { + JsValue::UnDefined + } + } +} + +impl JsClassDef for GGMLChatCompletionRequest { + type RefType = GGMLChatCompletionRequest; + + const CLASS_NAME: &'static str = "GGMLChatCompletionRequest"; + + const CONSTRUCTOR_ARGC: u8 = 0; + + const FIELDS: &'static [crate::JsClassField] = &[]; + + const METHODS: &'static [crate::JsClassMethod] = + &[("push_message", 2, Self::js_push_message)]; + + unsafe fn mut_class_id_ptr() -> &'static mut u32 { + static mut CLASS_ID: u32 = 0; + &mut CLASS_ID + } + + fn constructor_fn(_ctx: &mut Context, _argv: &[JsValue]) -> Result { + Ok(Self { + req: ChatCompletionRequest::default(), + }) + } +} + +pub fn init_ggml_template_module(ctx: &mut Context) { + ctx.register_fn_module( + "_wasi_nn_ggml_template", + &[ + GGMLChatCompletionRequest::CLASS_NAME, + GGMLChatPromptTemplate::CLASS_NAME, + ], + |ctx, m| { + let class_ctor = register_class::(ctx); + m.add_export(GGMLChatCompletionRequest::CLASS_NAME, class_ctor); + + let class_ctor = register_class::(ctx); + m.add_export(GGMLChatPromptTemplate::CLASS_NAME, class_ctor); + }, + ) +} diff --git a/src/internal_module/httpx/js_module.rs b/src/internal_module/httpx/js_module.rs index 97ab5a8..f132d7e 100644 --- a/src/internal_module/httpx/js_module.rs +++ b/src/internal_module/httpx/js_module.rs @@ -225,7 +225,7 @@ impl HttpRequest { } pub fn js_get_method(&self, ctx: &mut Context) -> JsValue { - ctx.new_string(&format!("{:?}", self.method)).into() + ctx.new_string(self.method.to_string().as_str()).into() } pub fn js_set_method(&mut self, _ctx: &mut Context, val: JsValue) { diff --git a/src/internal_module/mod.rs b/src/internal_module/mod.rs index ae739a8..9c3cde0 100644 --- a/src/internal_module/mod.rs +++ b/src/internal_module/mod.rs @@ -3,6 +3,8 @@ pub mod core; pub mod crypto; pub mod encoding; pub mod fs; +#[cfg(feature = "ggml")] +pub mod ggml; pub mod httpx; #[cfg(feature = "img")] pub mod img_module; diff --git a/src/quickjs_sys/js_class.rs b/src/quickjs_sys/js_class.rs index f4b2509..0d0c2f0 100644 --- a/src/quickjs_sys/js_class.rs +++ b/src/quickjs_sys/js_class.rs @@ -128,7 +128,7 @@ fn into_proto_function_list(p: JsClassProto) -> &'static [JSCFu u: JSCFunctionListEntry__bindgen_ty_1 { func: JSCFunctionListEntry__bindgen_ty_1__bindgen_ty_1 { length: argc, - cproto: JS_CFUNC_generic_magic as u8, + cproto: JSCFunctionEnum_JS_CFUNC_generic_magic as u8, cfunc: JSCFunctionType { generic_magic: Some(js_method_magic_trampoline::), }, @@ -142,6 +142,26 @@ fn into_proto_function_list(p: JsClassProto) -> &'static [JSCFu Vec::leak(entry_vec) } +pub struct SelfRefJsValue { + data: T, + val: JsValue, + _p: std::marker::PhantomData<(R, T)>, +} + +impl Deref for SelfRefJsValue { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.data + } +} + +impl DerefMut for SelfRefJsValue { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.data + } +} + pub trait JsClassTool: JsClassDef { fn class_id() -> u32; @@ -153,6 +173,28 @@ pub trait JsClassTool: JsClassDef { ctx.get_class_constructor(Self::class_id()) } + fn self_ref_opaque_mut( + js_obj: JsValue, + f: impl FnOnce(&'static Self::RefType) -> Result, + ) -> Option, E>> + where + Self: Sized, + { + unsafe { + let class_id = Self::class_id(); + let ptr = JS_GetOpaque(js_obj.get_qjs_value(), class_id) as *mut Self::RefType; + let r: &'static mut ::RefType = ptr.as_mut()?; + match f(r) { + Ok(data) => Some(Ok(SelfRefJsValue { + data, + val: js_obj, + _p: Default::default(), + })), + Err(e) => Some(Err(e)), + } + } + } + fn opaque_mut(js_obj: &mut JsValue) -> Option<&mut Self::RefType> { unsafe { let class_id = Self::class_id(); @@ -563,7 +605,7 @@ pub fn register_class(ctx: &mut Context) -> JsValue { Some(constructor::), class_name.as_ptr().cast(), Def::CONSTRUCTOR_ARGC as i32, - JS_CFUNC_constructor, + JSCFunctionEnum_JS_CFUNC_constructor, 0, ); diff --git a/src/quickjs_sys/mod.rs b/src/quickjs_sys/mod.rs index dabe6fe..6fa1004 100644 --- a/src/quickjs_sys/mod.rs +++ b/src/quickjs_sys/mod.rs @@ -403,6 +403,11 @@ impl Context { super::internal_module::crypto::init_module(&mut ctx); } + #[cfg(feature = "ggml")] + { + super::internal_module::ggml::init_wasi_nn_ggml_module(&mut ctx); + super::internal_module::ggml::init_ggml_template_module(&mut ctx); + } ctx } @@ -606,9 +611,21 @@ impl Context { } } - #[deprecated] pub fn promise_loop_poll(&mut self) { - todo!() + unsafe { + let rt = self.rt(); + let mut pctx: *mut JSContext = 0 as *mut JSContext; + + loop { + let err = JS_ExecutePendingJob(rt, (&mut pctx) as *mut *mut JSContext); + if err <= 0 { + if err < 0 { + js_std_dump_error(pctx); + } + break; + } + } + } } #[deprecated] @@ -634,7 +651,7 @@ impl Clone for Context { } unsafe fn to_u32(ctx: *mut JSContext, v: JSValue) -> Result { - if JS_VALUE_GET_NORM_TAG_real(v) == JS_TAG_INT { + if JS_VALUE_GET_NORM_TAG_real(v) == JS_TAG_JS_TAG_INT { let mut r = 0u32; JS_ToUint32_real(ctx, &mut r as *mut u32, v); Ok(r) @@ -691,13 +708,13 @@ impl Drop for JsRef { unsafe { let tag = JS_VALUE_GET_NORM_TAG_real(self.v); match tag { - JS_TAG_STRING - | JS_TAG_OBJECT - | JS_TAG_FUNCTION_BYTECODE - | JS_TAG_BIG_INT - | JS_TAG_BIG_FLOAT - | JS_TAG_BIG_DECIMAL - | JS_TAG_SYMBOL => JS_FreeValue_real(self.ctx, self.v), + JS_TAG_JS_TAG_STRING + | JS_TAG_JS_TAG_OBJECT + | JS_TAG_JS_TAG_FUNCTION_BYTECODE + | JS_TAG_JS_TAG_BIG_INT + | JS_TAG_JS_TAG_BIG_FLOAT + | JS_TAG_JS_TAG_BIG_DECIMAL + | JS_TAG_JS_TAG_SYMBOL => JS_FreeValue_real(self.ctx, self.v), _ => {} } } @@ -879,7 +896,7 @@ impl JsArray { let mut values = Vec::new(); for index in 0..(len as usize) { let value_raw = JS_GetPropertyUint32(ctx, v, index as u32); - if JS_VALUE_GET_NORM_TAG_real(value_raw) == JS_TAG_EXCEPTION { + if JS_VALUE_GET_NORM_TAG_real(value_raw) == JS_TAG_JS_TAG_EXCEPTION { return Err(JsException(JsRef { ctx, v: value_raw })); } let v = JsValue::from_qjs_value(ctx, value_raw); @@ -1066,22 +1083,22 @@ impl JsValue { unsafe { let tag = JS_VALUE_GET_NORM_TAG_real(v); match tag { - JS_TAG_INT => { + JS_TAG_JS_TAG_INT => { let mut num = 0; JS_ToInt32(ctx, (&mut num) as *mut i32, v); JsValue::Int(num) } - JS_TAG_FLOAT64 => { + JS_TAG_JS_TAG_FLOAT64 => { let mut num = 0_f64; JS_ToFloat64(ctx, (&mut num) as *mut f64, v); JsValue::Float(num) } - JS_TAG_BIG_DECIMAL | JS_TAG_BIG_INT | JS_TAG_BIG_FLOAT => { + JS_TAG_JS_TAG_BIG_DECIMAL | JS_TAG_JS_TAG_BIG_INT | JS_TAG_JS_TAG_BIG_FLOAT => { JsValue::BigNum(JsBigNum(JsRef { ctx, v })) } - JS_TAG_STRING => JsValue::String(JsString(JsRef { ctx, v })), - JS_TAG_MODULE => JsValue::Module(JsModule(JsRef { ctx, v })), - JS_TAG_OBJECT => { + JS_TAG_JS_TAG_STRING => JsValue::String(JsString(JsRef { ctx, v })), + JS_TAG_JS_TAG_MODULE => JsValue::Module(JsModule(JsRef { ctx, v })), + JS_TAG_JS_TAG_OBJECT => { if JS_IsFunction(ctx, v) != 0 { JsValue::Function(JsFunction(JsRef { ctx, v })) } else if JS_IsArrayBuffer(ctx, v) != 0 { @@ -1094,14 +1111,14 @@ impl JsValue { JsValue::Object(JsObject(JsRef { ctx, v })) } } - JS_TAG_BOOL => JsValue::Bool(JS_ToBool(ctx, v) != 0), - JS_TAG_NULL => JsValue::Null, - JS_TAG_EXCEPTION => JsValue::Exception(JsException(JsRef { ctx, v })), - JS_TAG_UNDEFINED => JsValue::UnDefined, - JS_TAG_FUNCTION_BYTECODE => { + JS_TAG_JS_TAG_BOOL => JsValue::Bool(JS_ToBool(ctx, v) != 0), + JS_TAG_JS_TAG_NULL => JsValue::Null, + JS_TAG_JS_TAG_EXCEPTION => JsValue::Exception(JsException(JsRef { ctx, v })), + JS_TAG_JS_TAG_UNDEFINED => JsValue::UnDefined, + JS_TAG_JS_TAG_FUNCTION_BYTECODE => { JsValue::FunctionByteCode(JsFunctionByteCode(JsRef { ctx, v })) } - JS_TAG_SYMBOL => JsValue::Symbol(JsRef { ctx, v }), + JS_TAG_JS_TAG_SYMBOL => JsValue::Symbol(JsRef { ctx, v }), _ => JsValue::Other(JsRef { ctx, v }), } }