From 12f3d50e0b55079e9b347ff65be5414148ed33ee Mon Sep 17 00:00:00 2001 From: Wilson Wang Date: Wed, 1 Jan 2025 20:33:06 +0800 Subject: [PATCH] use flatbuffers --- .gitignore | 5 +- Makefile | 56 ++- README.md | 12 +- cmd/{worker => spearlet}/main.go | 26 +- go.mod | 1 + go.sum | 2 + pkg/common/const.go | 13 +- pkg/net/client.go | 407 +++++++++++++++++ pkg/net/http.go | 4 - pkg/rpc/client.go | 263 ----------- pkg/rpc/json.go | 188 -------- pkg/rpc/payload/chat.go | 79 ---- pkg/rpc/payload/const.go | 24 - pkg/rpc/payload/msgpassing.go | 101 ----- pkg/rpc/payload/tools.go | 87 ---- pkg/rpc/payload/transform.go | 89 ---- pkg/rpc/payload/transform/const.go | 5 - pkg/rpc/payload/transform/embeddings.go | 55 --- pkg/rpc/payload/transform/image_gen.go | 36 -- pkg/rpc/payload/transform/speech2text.go | 28 -- pkg/rpc/payload/transform/text2speech.go | 30 -- pkg/rpc/payload/vectorstore.go | 102 ----- pkg/rpc/transform.go | 150 ------- pkg/tools/docker/docker.go | 10 +- pkg/utils/protohelper/helper.go | 141 ++++++ proto/chat/common.fbs | 26 ++ proto/chat/request.fbs | 17 + proto/chat/response.fbs | 11 + proto/custom/request.fbs | 8 + proto/custom/response.fbs | 7 + proto/io/input-req.fbs | 9 + proto/io/input-resp.fbs | 8 + proto/io/record-req.fbs | 8 + proto/io/record-resp.fbs | 7 + proto/io/speak-req.fbs | 11 + proto/io/speak-resp.fbs | 6 + proto/tool/common.fbs | 47 ++ proto/tool/internal-req.fbs | 18 + proto/tool/internal-resp.fbs | 9 + proto/tool/invoke-req.fbs | 16 + proto/tool/invoke-resp.fbs | 7 + proto/transform/common.fbs | 24 + proto/transform/request.fbs | 23 + proto/transform/response.fbs | 20 + proto/transport/request.fbs | 31 ++ proto/transport/response.fbs | 12 + proto/transport/signal.fbs | 15 + proto/transport/transport.fbs | 17 + sdk/cpp/Makefile | 18 + sdk/python/Makefile | 22 +- sdk/python/project.toml | 2 +- sdk/python/requirements.txt | 4 + sdk/python/setup.py | 10 +- sdk/python/spear/client.py | 421 ++++++++---------- sdk/python/spear/transform/__init__.py | 0 sdk/python/spear/transform/chat.py | 130 ++++++ sdk/python/spear/utils/io.py | 94 +++- sdk/python/spear/utils/tool.py | 95 ++++ sdk/python/tests/__init__.py | 0 sdk/python/tests/proto/__init__.py | 0 sdk/python/tests/proto/server.py | 219 +++++++++ sdk/python/tests/proto/test_chat.py | 57 +++ spearlet/hostcalls/chat.go | 419 +++++++++++++++++ spearlet/hostcalls/common/common.go | 321 +++++++++++++ .../hostcalls/common/models.go | 36 +- spearlet/hostcalls/common/tools.go | 133 ++++++ spearlet/hostcalls/embeddings.go | 39 ++ spearlet/hostcalls/gen_image.go | 50 +++ spearlet/hostcalls/hc_entries.go | 100 +++++ .../hostcalls/huggingface/huggingface_hc.go | 84 ++++ .../hostcalls/io.go | 128 ++++-- spearlet/hostcalls/msgpassing.go | 143 ++++++ .../hostcalls/openai/openai_hc.go | 115 +++-- spearlet/hostcalls/stt.go | 57 +++ spearlet/hostcalls/tools.go | 51 +++ spearlet/hostcalls/transform.go | 168 +++++++ spearlet/hostcalls/tts.go | 56 +++ spearlet/hostcalls/vectorstore.go | 299 +++++++++++++ worker/worker.go => spearlet/spearlet.go | 140 +++--- {worker => spearlet}/task/const.go | 0 {worker => spearlet}/task/docker.go | 32 +- {worker => spearlet}/task/docker/utils.go | 0 {worker => spearlet}/task/docker_rt.go | 7 +- {worker => spearlet}/task/proc.go | 10 +- {worker => spearlet}/task/proc_rt.go | 0 {worker => spearlet}/task/task.go | 8 +- {worker => spearlet}/tools/contact.go | 5 +- {worker => spearlet}/tools/datetime.go | 8 +- {worker => spearlet}/tools/email.go | 15 +- {worker => spearlet}/tools/mouse.go | 8 +- {worker => 
spearlet}/tools/phone.go | 5 +- {worker => spearlet}/tools/screen.go | 5 +- {worker => spearlet}/tools/web.go | 20 +- {worker => spearlet}/types.go | 2 +- test/functionality_test.go | 25 ++ test/simple_req_test.go | 119 +---- test/test_guide.md | 2 +- worker/hostcalls/chat.go | 385 ---------------- worker/hostcalls/common/common.go | 271 ----------- worker/hostcalls/common/tools.go | 43 -- worker/hostcalls/embeddings.go | 39 -- worker/hostcalls/gen_image.go | 52 --- worker/hostcalls/hc_entries.go | 98 ---- .../hostcalls/huggingface/huggingface_hc.go | 88 ---- worker/hostcalls/msgpassing.go | 140 ------ worker/hostcalls/stt.go | 40 -- worker/hostcalls/tools.go | 186 -------- worker/hostcalls/transform.go | 189 -------- worker/hostcalls/tts.go | 37 -- worker/hostcalls/utils.go | 37 -- worker/hostcalls/vectorstore.go | 293 ------------ workload/docker/go/dummy/Dockerfile | 12 - workload/docker/go/dummy/Makefile | 14 - workload/docker/go/dummy/compose.yaml | 8 - workload/docker/go/dummy/scripts/start.sh | 9 - workload/docker/go/dummy/src/start.go | 158 ------- workload/docker/go/gen_image/Makefile | 10 +- workload/docker/go/gen_image/compose.yaml | 4 +- workload/docker/go/gen_image/src/demo.go | 122 ----- workload/docker/go/gen_image/src/start.go | 55 ++- workload/docker/go/voice_chat/Dockerfile | 8 - workload/docker/go/voice_chat/Makefile | 23 - workload/docker/go/voice_chat/compose.yaml | 6 - workload/docker/go/voice_chat/src/demo.go | 137 ------ workload/docker/go/voice_chat/src/start.go | 118 ----- workload/docker/python/pychat/Dockerfile | 2 +- workload/docker/python/pychat/Makefile | 2 +- workload/docker/python/pychat/compose.yaml | 4 +- .../python/pyconversation-local/Dockerfile | 6 +- .../python/pyconversation-local/Makefile | 2 +- .../python/pyconversation-local/compose.yaml | 4 +- workload/docker/python/pydummy/Dockerfile | 6 +- workload/docker/python/pydummy/Makefile | 2 +- workload/docker/python/pydummy/compose.yaml | 2 +- workload/docker/python/pydummy/src/start.py | 34 +- .../python/pytest-functionality/Dockerfile | 19 + .../python/pytest-functionality/Makefile | 13 + .../python/pytest-functionality/compose.yaml | 6 + .../python/pytest-functionality/src/start.py | 120 +++++ workload/docker/python/pytools/Dockerfile | 2 +- workload/docker/python/pytools/Makefile | 2 +- workload/docker/python/pytools/compose.yaml | 4 +- workload/process/dummy/Makefile | 11 - workload/process/dummy/main.go | 101 ----- 144 files changed, 4323 insertions(+), 4554 deletions(-) rename cmd/{worker => spearlet}/main.go (87%) create mode 100644 pkg/net/client.go delete mode 100644 pkg/rpc/client.go delete mode 100644 pkg/rpc/json.go delete mode 100644 pkg/rpc/payload/chat.go delete mode 100644 pkg/rpc/payload/const.go delete mode 100644 pkg/rpc/payload/msgpassing.go delete mode 100644 pkg/rpc/payload/tools.go delete mode 100644 pkg/rpc/payload/transform.go delete mode 100644 pkg/rpc/payload/transform/const.go delete mode 100644 pkg/rpc/payload/transform/embeddings.go delete mode 100644 pkg/rpc/payload/transform/image_gen.go delete mode 100644 pkg/rpc/payload/transform/speech2text.go delete mode 100644 pkg/rpc/payload/transform/text2speech.go delete mode 100644 pkg/rpc/payload/vectorstore.go delete mode 100644 pkg/rpc/transform.go create mode 100644 pkg/utils/protohelper/helper.go create mode 100644 proto/chat/common.fbs create mode 100644 proto/chat/request.fbs create mode 100644 proto/chat/response.fbs create mode 100644 proto/custom/request.fbs create mode 100644 proto/custom/response.fbs create mode 
100644 proto/io/input-req.fbs create mode 100644 proto/io/input-resp.fbs create mode 100644 proto/io/record-req.fbs create mode 100644 proto/io/record-resp.fbs create mode 100644 proto/io/speak-req.fbs create mode 100644 proto/io/speak-resp.fbs create mode 100644 proto/tool/common.fbs create mode 100644 proto/tool/internal-req.fbs create mode 100644 proto/tool/internal-resp.fbs create mode 100644 proto/tool/invoke-req.fbs create mode 100644 proto/tool/invoke-resp.fbs create mode 100644 proto/transform/common.fbs create mode 100644 proto/transform/request.fbs create mode 100644 proto/transform/response.fbs create mode 100644 proto/transport/request.fbs create mode 100644 proto/transport/response.fbs create mode 100644 proto/transport/signal.fbs create mode 100644 proto/transport/transport.fbs create mode 100644 sdk/cpp/Makefile create mode 100644 sdk/python/requirements.txt create mode 100644 sdk/python/spear/transform/__init__.py create mode 100644 sdk/python/spear/transform/chat.py create mode 100644 sdk/python/spear/utils/tool.py create mode 100644 sdk/python/tests/__init__.py create mode 100644 sdk/python/tests/proto/__init__.py create mode 100644 sdk/python/tests/proto/server.py create mode 100644 sdk/python/tests/proto/test_chat.py create mode 100644 spearlet/hostcalls/chat.go create mode 100644 spearlet/hostcalls/common/common.go rename {worker => spearlet}/hostcalls/common/models.go (87%) create mode 100644 spearlet/hostcalls/common/tools.go create mode 100644 spearlet/hostcalls/embeddings.go create mode 100644 spearlet/hostcalls/gen_image.go create mode 100644 spearlet/hostcalls/hc_entries.go create mode 100644 spearlet/hostcalls/huggingface/huggingface_hc.go rename worker/hostcalls/local_utils.go => spearlet/hostcalls/io.go (53%) create mode 100644 spearlet/hostcalls/msgpassing.go rename {worker => spearlet}/hostcalls/openai/openai_hc.go (82%) create mode 100644 spearlet/hostcalls/stt.go create mode 100644 spearlet/hostcalls/tools.go create mode 100644 spearlet/hostcalls/transform.go create mode 100644 spearlet/hostcalls/tts.go create mode 100644 spearlet/hostcalls/vectorstore.go rename worker/worker.go => spearlet/spearlet.go (65%) rename {worker => spearlet}/task/const.go (100%) rename {worker => spearlet}/task/docker.go (91%) rename {worker => spearlet}/task/docker/utils.go (100%) rename {worker => spearlet}/task/docker_rt.go (96%) rename {worker => spearlet}/task/proc.go (94%) rename {worker => spearlet}/task/proc_rt.go (100%) rename {worker => spearlet}/task/task.go (98%) rename {worker => spearlet}/tools/contact.go (90%) rename {worker => spearlet}/tools/datetime.go (81%) rename {worker => spearlet}/tools/email.go (92%) rename {worker => spearlet}/tools/mouse.go (79%) rename {worker => spearlet}/tools/phone.go (90%) rename {worker => spearlet}/tools/screen.go (87%) rename {worker => spearlet}/tools/web.go (87%) rename {worker => spearlet}/types.go (87%) create mode 100644 test/functionality_test.go delete mode 100644 worker/hostcalls/chat.go delete mode 100644 worker/hostcalls/common/common.go delete mode 100644 worker/hostcalls/common/tools.go delete mode 100644 worker/hostcalls/embeddings.go delete mode 100644 worker/hostcalls/gen_image.go delete mode 100644 worker/hostcalls/hc_entries.go delete mode 100644 worker/hostcalls/huggingface/huggingface_hc.go delete mode 100644 worker/hostcalls/msgpassing.go delete mode 100644 worker/hostcalls/stt.go delete mode 100644 worker/hostcalls/tools.go delete mode 100644 worker/hostcalls/transform.go delete mode 100644 
worker/hostcalls/tts.go delete mode 100644 worker/hostcalls/utils.go delete mode 100644 worker/hostcalls/vectorstore.go delete mode 100644 workload/docker/go/dummy/Dockerfile delete mode 100644 workload/docker/go/dummy/Makefile delete mode 100644 workload/docker/go/dummy/compose.yaml delete mode 100755 workload/docker/go/dummy/scripts/start.sh delete mode 100644 workload/docker/go/dummy/src/start.go delete mode 100644 workload/docker/go/gen_image/src/demo.go delete mode 100644 workload/docker/go/voice_chat/Dockerfile delete mode 100644 workload/docker/go/voice_chat/Makefile delete mode 100644 workload/docker/go/voice_chat/compose.yaml delete mode 100644 workload/docker/go/voice_chat/src/demo.go delete mode 100644 workload/docker/go/voice_chat/src/start.go create mode 100644 workload/docker/python/pytest-functionality/Dockerfile create mode 100644 workload/docker/python/pytest-functionality/Makefile create mode 100644 workload/docker/python/pytest-functionality/compose.yaml create mode 100755 workload/docker/python/pytest-functionality/src/start.py delete mode 100644 workload/process/dummy/Makefile delete mode 100644 workload/process/dummy/main.go diff --git a/.gitignore b/.gitignore index ef6a9bd..44efb6c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,9 @@ bin bin/dummy_task -bin/worker +bin/spearlet .idea .DS_Store *.whl *.egg-info -sdk/python/dist \ No newline at end of file +sdk/python/dist +__pycache__ \ No newline at end of file diff --git a/Makefile b/Makefile index 081b37f..c50128e 100644 --- a/Makefile +++ b/Makefile @@ -1,34 +1,56 @@ -PROJECT_ROOT := $(shell pwd) -OUTPUT_DIR := $(PROJECT_ROOT)/bin +REPO_ROOT := $(shell pwd) +OUTPUT_DIR := $(REPO_ROOT)/bin -all: clean worker workload sdk +all: clean spearlet workload sdk + + +SUBDIRS := $(shell find $(REPO_ROOT) -mindepth 1 -maxdepth 3 -type d -exec test -e {}/Makefile \; -exec echo {} \;) +WORKLOAD_SUBDIRS := $(shell find $(REPO_ROOT)/workload -mindepth 1 -maxdepth 3 -type d -exec test -e {}/Makefile \; -exec echo {} \;) clean: + @set -ex; \ + docker system prune -f && \ rm -rf $(OUTPUT_DIR) && \ - find $(PROJECT_ROOT)/workload -mindepth 1 -maxdepth 3 -type d -exec test -e {}/Makefile \; -exec make -C {} clean \; - find $(PROJECT_ROOT)/sdk -mindepth 1 -maxdepth 2 -type d -exec test -e {}/Makefile \; -exec make -C {} clean \; - -worker: - go build -o $(OUTPUT_DIR)/worker \ - $(PROJECT_ROOT)/cmd/worker/main.go + rm -rf $(REPO_ROOT)/pkg/spear && \ + for dir in $(SUBDIRS); do \ + make -C $$dir clean; \ + done -test: workload - go test -v $(PROJECT_ROOT)/test/... +build: spearlet + @set -e; \ + for dir in $(SUBDIRS); do \ + make -C $$dir build; \ + done -workload: sdk - find $(PROJECT_ROOT)/workload -mindepth 1 -maxdepth 3 -type d -exec test -e {}/Makefile \; -exec echo "make -C {}" \; -exec make -C {} \; +spearlet: pkg/spear + go build -o $(OUTPUT_DIR)/spearlet \ + $(REPO_ROOT)/cmd/spearlet/main.go -sdk: - find $(PROJECT_ROOT)/sdk -mindepth 1 -maxdepth 2 -type d -exec test -e {}/Makefile \; -exec make -C {} \; +test: workload + @set -e; \ + go test -v $(REPO_ROOT)/test/... && \ + for dir in $(SUBDIRS); do \ + make -C $$dir test; \ + done + +workload: build + @set -e; \ + for dir in $(WORKLOAD_SUBDIRS); do \ + make -C $$dir; \ + done format_python: - isort -rc $(PROJECT_ROOT)/ + isort -rc $(REPO_ROOT)/ format_golang: gofmt -w . 
format: format_python format_golang -.PHONY: all worker test workload clean sdk format_python format +pkg/spear: + allfiles=`find ${REPO_ROOT}/proto -name "*.fbs"`; \ + flatc -o $(REPO_ROOT)/pkg/ -I ${REPO_ROOT}/proto --go-module-name "github.com/lfedgeai/spear/pkg" --go $${allfiles} + +.PHONY: all spearlet test workload clean format_python format diff --git a/README.md b/README.md index b7e36c4..f6bcdab 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ SPEAR is an advanced AI Agent platform designed to support multiple runtime envi ```bash python -m pip install --upgrade pip pip install build - apt install portaudio19-dev libx11-dev libxtst-dev + apt install portaudio19-dev libx11-dev libxtst-dev flatbuffers-compiler curl -fsSL https://get.docker.com -o get-docker.sh sh get-docker.sh ``` @@ -105,11 +105,11 @@ To run SPEAR in local mode, use the following command: export OPENAI_API_KEY= export HUGGINGFACEHUB_API_TOKEN= export SPEAR_RPC_ADDR= -bin/worker exec -n pyconversation +bin/spearlet exec -n pyconversation ``` This command will: - - Start the SPEAR worker process in local mode. + - Start the SPEAR spearlet process in local mode. - Run the AI agent workload with an ID of 6. (pyconversation-local) Also, you need to set the environment variable `OPENAI_API_KEY` to your OpenAI API key. In the future, we will support other LLM providers. @@ -122,7 +122,7 @@ Also, you need to set the environment variable `OPENAI_API_KEY` to your OpenAI A PortAudio is required for the audio processing component. To install PortAudio on MacOS, use the following command: ```bash - brew install portaudio + brew install portaudio flatbuffers ``` ### Build Instructions @@ -142,11 +142,11 @@ To run SPEAR in local mode, use the following command: ```bash export OPENAI_API_KEY= -bin/worker exec -n pyconversation +bin/spearlet exec -n pyconversation ``` This command will: - - Start the SPEAR worker process in local mode. + - Start the SPEAR spearlet process in local mode. - Run the AI agent workload with an ID of 6. (pyconversation-local) Also, you need to set the environment variable `OPENAI_API_KEY` to your OpenAI API key. In the future, we will support other LLM providers. 
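The new `pkg/spear` make target above runs `flatc` over every `.fbs` schema under `proto/` and emits Go bindings into `pkg/spear`. As a rough sketch of what using those generated bindings looks like on the guest side — illustration only, not part of the patch; the package path, table names, and builder calls are the ones that appear in `pkg/net/client.go` further down, while the literal payload bytes are made up — a transport request can be built like this:

```go
// Illustration only, not part of the patch: builds a TransportRequest the same
// way pkg/net/client.go does, using the flatc-generated package
// github.com/lfedgeai/spear/pkg/spear/proto/transport.
package main

import (
	"fmt"

	flatbuffers "github.com/google/flatbuffers/go"
	"github.com/lfedgeai/spear/pkg/spear/proto/transport"
)

func main() {
	builder := flatbuffers.NewBuilder(512)
	// The payload would normally be another serialized table (e.g. a chat
	// request); a placeholder byte string is used here purely for illustration.
	payload := builder.CreateByteVector([]byte("hello"))

	transport.TransportRequestStart(builder)
	transport.TransportRequestAddId(builder, 1)
	transport.TransportRequestAddMethod(builder, transport.MethodCustom)
	transport.TransportRequestAddRequest(builder, payload)
	builder.Finish(transport.TransportRequestEnd(builder))

	// The receiving side reads the same buffer back without a separate
	// unmarshal step, exactly as the host code does with GetRootAs*.
	buf := builder.FinishedBytes()
	req := transport.GetRootAsTransportRequest(buf, 0)
	fmt.Println(req.Id(), len(req.RequestBytes()))
}
```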
diff --git a/cmd/worker/main.go b/cmd/spearlet/main.go similarity index 87% rename from cmd/worker/main.go rename to cmd/spearlet/main.go index dc699fb..561f223 100644 --- a/cmd/worker/main.go +++ b/cmd/spearlet/main.go @@ -2,15 +2,15 @@ package main import ( "github.com/lfedgeai/spear/pkg/common" - "github.com/lfedgeai/spear/worker" - "github.com/lfedgeai/spear/worker/task" + spearlet "github.com/lfedgeai/spear/spearlet" + "github.com/lfedgeai/spear/spearlet/task" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "os" ) -type WorkerConfig struct { +type SpearletConfig struct { Addr string Port string } @@ -27,8 +27,8 @@ var ( func NewRootCmd() *cobra.Command { var rootCmd = &cobra.Command{ - Use: "worker", - Short: "Worker is the command line tool for the worker", + Use: "spearlet", + Short: "spearlet is the command line tool for the SPEAR spearlet", Run: func(cmd *cobra.Command, args []string) { cmd.Help() }, @@ -65,12 +65,12 @@ func NewRootCmd() *cobra.Command { log.Infof("Executing workload %s with runtime type %v", execWorkloadName, rtType) // set log level if execVerbose { - worker.SetLogLevel(log.DebugLevel) + spearlet.SetLogLevel(log.DebugLevel) } // create config - config := worker.NewExecWorkerConfig(execDebug, execSpearAddr) - w := worker.NewWorker(config) + config := spearlet.NewExecSpearletConfig(execDebug, execSpearAddr) + w := spearlet.NewSpearlet(config) w.Initialize() // lookup task id @@ -110,7 +110,7 @@ func NewRootCmd() *cobra.Command { var serveCmd = &cobra.Command{ Use: "serve", - Short: "Start the worker server", + Short: "Start the spearlet server", Run: func(cmd *cobra.Command, args []string) { // parse flags addr, _ := cmd.Flags().GetString("addr") @@ -121,7 +121,7 @@ func NewRootCmd() *cobra.Command { // set log level if verbose { - worker.SetLogLevel(log.DebugLevel) + spearlet.SetLogLevel(log.DebugLevel) } if execSpearAddr == "" { @@ -129,9 +129,9 @@ func NewRootCmd() *cobra.Command { } // create config - config := worker.NewServeWorkerConfig(addr, port, paths, + config := spearlet.NewServeSpearletConfig(addr, port, paths, debug, execSpearAddr) - w := worker.NewWorker(config) + w := spearlet.NewSpearlet(config) w.Initialize() w.StartServer() }, @@ -144,7 +144,7 @@ func NewRootCmd() *cobra.Command { serveCmd.PersistentFlags().BoolP("verbose", "v", false, "verbose output") // search path serveCmd.PersistentFlags().StringArrayP("search-path", "L", []string{}, - "search path list for the worker") + "search path list for the spearlet") // debug flag serveCmd.PersistentFlags().BoolP("debug", "d", false, "debug mode") rootCmd.AddCommand(serveCmd) diff --git a/go.mod b/go.mod index 573e8fa..5e45fe6 100644 --- a/go.mod +++ b/go.mod @@ -82,6 +82,7 @@ require ( github.com/chromedp/chromedp v0.11.2 github.com/faiface/beep v1.1.0 github.com/go-vgo/robotgo v0.110.5 + github.com/google/flatbuffers v24.3.25+incompatible github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 github.com/hajimehoshi/go-mp3 v0.3.4 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 62faf0a..561ec6d 100644 --- a/go.sum +++ b/go.sum @@ -71,6 +71,8 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69 github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/google/flatbuffers 
v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI= +github.com/google/flatbuffers v24.3.25+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/pkg/common/const.go b/pkg/common/const.go index 71d23de..bc03aca 100644 --- a/pkg/common/const.go +++ b/pkg/common/const.go @@ -1,4 +1,15 @@ package common +import "runtime" + const MaxDataResponseSize = 4096 * 1024 -const SpearPlatformAddress = "172.17.0.1" + +var SpearPlatformAddress string + +func init() { + SpearPlatformAddress = map[string]string{ + "darwin": "host.docker.internal", + "linux": "172.17.0.1", + "windows": "host.docker.internal", + }[runtime.GOOS] +} diff --git a/pkg/net/client.go b/pkg/net/client.go new file mode 100644 index 0000000..9caa303 --- /dev/null +++ b/pkg/net/client.go @@ -0,0 +1,407 @@ +package net + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "sync" + "time" + + flatbuffers "github.com/google/flatbuffers/go" + "github.com/lfedgeai/spear/pkg/spear/proto/custom" + "github.com/lfedgeai/spear/pkg/spear/proto/transport" + log "github.com/sirupsen/logrus" +) + +type RequestHandler func(req *transport.TransportRequest) (*transport.TransportResponse, error) +type ResponseHandler func(resp *transport.TransportResponse) error +type CustomRequestHandler func(req *custom.CustomRequest) (*custom.CustomResponse, error) + +type GuestRPCManager struct { + reqHandler map[transport.Method]RequestHandler + customReqHandler map[string]CustomRequestHandler + pendingRequests map[int64]reqCallbackStruct + pendingRequestsMu sync.RWMutex + input io.Reader + output io.Writer + + globalIDCounter int64 +} + +type reqCallbackStruct struct { + cb ResponseHandler + timeStamp time.Time + autoClear bool +} + +const ( + ResponseTimeout = time.Minute * 10 // 10 minutes timeout for requests +) + +func RPCManagerSendRequest[T any](rpcMgr *GuestRPCManager, method transport.Method, + params []byte) (*T, error) { + resp, err := rpcMgr.SendRequest(method, params) + if err != nil { + return nil, err + } + // first marshal to json + jsonData, err := json.Marshal(resp) + if err != nil { + return nil, err + } + // then unmarshal to T + var resp2 T + err = json.Unmarshal(jsonData, &resp2) + if err != nil { + return nil, err + } + return &resp2, nil +} + +func NewGuestRPCManager() *GuestRPCManager { + res := &GuestRPCManager{ + reqHandler: make(map[transport.Method]RequestHandler), + customReqHandler: make(map[string]CustomRequestHandler), + pendingRequests: make(map[int64]reqCallbackStruct), + pendingRequestsMu: sync.RWMutex{}, + globalIDCounter: 1, + } + + res.reqHandler[transport.MethodCustom] = + func(req *transport.TransportRequest) (*transport.TransportResponse, error) { + data := req.RequestBytes() + customReq := custom.GetRootAsCustomRequest(data, 0) + if customReq == nil { + return nil, fmt.Errorf("error unmarshalling custom request") + } + if hdl, ok := res.customReqHandler[string(customReq.MethodStr())]; ok { + resp, err := hdl(customReq) + if err != nil { + return nil, err + } + builder := flatbuffers.NewBuilder(512) + respOff := builder.CreateByteVector(resp.DataBytes()) + + transport.TransportResponseStart(builder) + transport.TransportResponseAddId(builder, req.Id()) + transport.TransportResponseAddResponse(builder, respOff) + 
builder.Finish(transport.TransportResponseEnd(builder)) + + data := builder.FinishedBytes() + transResp := transport.GetRootAsTransportResponse(data, 0) + if transResp == nil { + return nil, fmt.Errorf("error unmarshalling response") + } + return transResp, nil + } + return nil, fmt.Errorf("no handler for custom method %s", + customReq.MethodStr()) + } + + return res +} + +func (g *GuestRPCManager) SetInput(i io.Reader) { + g.input = i +} + +func (g *GuestRPCManager) SetOutput(o io.Writer) { + g.output = o +} + +func (g *GuestRPCManager) SetRequestCallback(id int64, callback ResponseHandler, + autoClear bool) { + g.pendingRequestsMu.Lock() + defer g.pendingRequestsMu.Unlock() + g.pendingRequests[id] = reqCallbackStruct{ + cb: callback, + timeStamp: time.Now(), + autoClear: autoClear, + } +} + +func (g *GuestRPCManager) ClearRequestCallback(id int64) { + g.pendingRequestsMu.Lock() + defer g.pendingRequestsMu.Unlock() + delete(g.pendingRequests, id) +} + +func (g *GuestRPCManager) RegisterIncomingCustomRequestHandler(method string, + handler CustomRequestHandler) error { + if _, ok := g.customReqHandler[method]; ok { + return fmt.Errorf("handler already registered for method %s", method) + } + g.customReqHandler[method] = handler + return nil +} + +func (g *GuestRPCManager) RegisterIncomingRequestHandler(method transport.Method, + handler RequestHandler) error { + if method == transport.MethodCustom { + return fmt.Errorf("cannot register handler for custom method") + } + if _, ok := g.reqHandler[method]; ok { + return fmt.Errorf("handler already registered for method %s", method) + } + g.reqHandler[method] = handler + return nil +} + +// high level function to send a request +func (g *GuestRPCManager) SendRequest(method transport.Method, + params []byte) ([]byte, error) { + builder := flatbuffers.NewBuilder(512) + paramOff := builder.CreateByteVector(params) + + transport.TransportRequestStart(builder) + transport.TransportRequestAddMethod(builder, method) + transport.TransportRequestAddRequest(builder, paramOff) + builder.Finish(transport.TransportRequestEnd(builder)) + + data := builder.FinishedBytes() + req := transport.GetRootAsTransportRequest(data, 0) + if req == nil { + return nil, fmt.Errorf("error unmarshalling request") + } + if resp, err := g.SendTransportRequest(req); err != nil { + return nil, err + } else { + return resp.ResponseBytes(), nil + } +} + +// low level function to send a json request +func (g *GuestRPCManager) SendTransportRequest( + req *transport.TransportRequest) (*transport.TransportResponse, error) { + if g.output == nil { + return nil, fmt.Errorf("output file not set") + } + builder := flatbuffers.NewBuilder(512) + reqOffset := builder.CreateByteVector(req.RequestBytes()) + + transport.TransportRequestStart(builder) + transport.TransportRequestAddId(builder, g.globalIDCounter) + defer func() { + g.globalIDCounter++ + }() + transport.TransportRequestAddMethod(builder, req.Method()) + transport.TransportRequestAddRequest(builder, reqOffset) + off := transport.TransportRequestEnd(builder) + + transport.TransportMessageRawStart(builder) + transport.TransportMessageRawAddDataType(builder, + transport.TransportMessageRaw_DataTransportRequest) + transport.TransportMessageRawAddData(builder, off) + builder.Finish(transport.TransportMessageRawEnd(builder)) + + data := builder.FinishedBytes() + dataLen := uint64(len(data)) + + ch := make(chan *transport.TransportResponse, 1) + g.SetRequestCallback(g.globalIDCounter, + func(resp *transport.TransportResponse) error { + ch 
<- resp + return nil + }, true) + + // write data length + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, dataLen) + if _, err := g.output.Write(buf); err != nil { + return nil, err + } + if _, err := g.output.Write(data); err != nil { + return nil, err + } + + // wait for response + select { + case <-time.After(ResponseTimeout): + return nil, fmt.Errorf("timeout waiting for response") + case resp := <-ch: + return resp, nil + } + +} + +func (g *GuestRPCManager) SendTransportResponse(id int64, + resp *transport.TransportResponse) error { + if g.output == nil { + return fmt.Errorf("output file not set") + } + builder := flatbuffers.NewBuilder(512) + respOffset := builder.CreateByteVector(resp.ResponseBytes()) + + transport.TransportResponseStart(builder) + transport.TransportResponseAddId(builder, id) + transport.TransportResponseAddResponse(builder, respOffset) + transportOff := transport.TransportResponseEnd(builder) + + transport.TransportMessageRawStart(builder) + transport.TransportMessageRawAddDataType(builder, + transport.TransportMessageRaw_DataTransportResponse) + transport.TransportMessageRawAddData(builder, transportOff) + builder.Finish(transport.TransportMessageRawEnd(builder)) + + data := builder.FinishedBytes() + dataLen := uint64(len(data)) + + // write data length + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, dataLen) + if _, err := g.output.Write(buf); err != nil { + return err + } + if _, err := g.output.Write(data); err != nil { + return err + } + return nil +} + +func (g *GuestRPCManager) sendErrorTransportResponse(id int64, + err error) error { + if g.output == nil { + return fmt.Errorf("output file not set") + } + builder := flatbuffers.NewBuilder(512) + errMsg := builder.CreateString(err.Error()) + + transport.TransportResponseStart(builder) + transport.TransportResponseAddId(builder, id) + transport.TransportResponseAddCode(builder, -1) + transport.TransportResponseAddMessage(builder, errMsg) + transportOff := transport.TransportResponseEnd(builder) + + transport.TransportMessageRawStart(builder) + transport.TransportMessageRawAddDataType(builder, + transport.TransportMessageRaw_DataTransportResponse) + transport.TransportMessageRawAddData(builder, transportOff) + builder.Finish(transport.TransportMessageRawEnd(builder)) + + data := builder.FinishedBytes() + dataLen := uint64(len(data)) + + // write data length + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, dataLen) + if _, err := g.output.Write(buf); err != nil { + return err + } + if _, err := g.output.Write(data); err != nil { + return err + } + return nil +} + +func (g *GuestRPCManager) Run() { + // read from stdin + reader := g.input + + for { + // read a 64 bit uint + buf := make([]byte, 8) + if _, err := reader.Read(buf); err != nil { + log.Errorf("Error reading from stdin: %v", err) + continue + } + dataLen := binary.LittleEndian.Uint64(buf) + + if dataLen == 0 { + log.Infof("Exiting") + break + } + + log.Debugf("Got message size: %d", dataLen) + // read dataLen bytes + data := make([]byte, dataLen) + if _, err := io.ReadFull(reader, data); err != nil { + log.Errorf("Error reading from stdin: %v", err) + continue + } + + if len(data) == 0 { + log.Infof("Exiting") + break + } + + req := transport.GetRootAsTransportMessageRaw(data, 0) + if req == nil { + log.Errorf("Error unmarshalling request") + break + } + + if req.DataType() == + transport.TransportMessageRaw_DataTransportRequest { + // request + transportReq := &transport.TransportRequest{} + tbl := 
flatbuffers.Table{} + if !req.Data(&tbl) { + log.Errorf("Error getting data from request") + break + } + transportReq.Init(tbl.Bytes, tbl.Pos) + if hdl, ok := g.reqHandler[transportReq.Method()]; ok { + go func() { + resp, err := hdl(transportReq) + if err != nil { + log.Errorf("Error handling request: %v", err) + if err := g.sendErrorTransportResponse(transportReq.Id(), + err); err != nil { + log.Errorf("Error sending error response: %v", err) + } + } else { + log.Debugf("Sending response for method %s", + transportReq.Method()) + if err := g.SendTransportResponse(transportReq.Id(), + resp); err != nil { + log.Errorf("Error sending response: %v", err) + } + } + }() + } + // TODO: handle request + } else if req.DataType() == + transport.TransportMessageRaw_DataTransportResponse { + // response + transportResp := &transport.TransportResponse{} + tbl := flatbuffers.Table{} + if !req.Data(&tbl) { + log.Errorf("Error getting data from response") + break + } + transportResp.Init(tbl.Bytes, tbl.Pos) + // check pending requests + g.pendingRequestsMu.RLock() + defer g.pendingRequestsMu.RUnlock() + callback, ok := g.pendingRequests[transportResp.Id()] + if ok { + go func() { + if err := callback.cb(transportResp); err != nil { + log.Errorf("Error handling response: %v", err) + } + if callback.autoClear { + g.ClearRequestCallback(transportResp.Id()) + } + }() + } else { + log.Errorf("No callback for response id %d", transportResp.Id()) + } + } else if req.DataType() == + transport.TransportMessageRaw_DataTransportSignal { + // signal + transportSig := &transport.TransportSignal{} + tbl := flatbuffers.Table{} + if !req.Data(&tbl) { + log.Errorf("Error getting data from signal") + break + } + transportSig.Init(tbl.Bytes, tbl.Pos) + log.Infof("Got signal: %s. But it is not supported yet.", + transportSig.Method().String()) + } else { + log.Errorf("Invalid data type: %v", req.DataType()) + break + } + } +} diff --git a/pkg/net/http.go b/pkg/net/http.go index f0b4567..81c36eb 100644 --- a/pkg/net/http.go +++ b/pkg/net/http.go @@ -5,8 +5,6 @@ import ( "fmt" "io" "net/http" - - log "github.com/sirupsen/logrus" ) type ContentType int @@ -24,8 +22,6 @@ func SendRequest(url string, data *bytes.Buffer, contentType interface{}, apiKey return nil, fmt.Errorf("error creating request: %v", err) } - log.Infof("Sending request to %s", url) - switch typ := contentType.(type) { case ContentType: switch typ { diff --git a/pkg/rpc/client.go b/pkg/rpc/client.go deleted file mode 100644 index d6d00b2..0000000 --- a/pkg/rpc/client.go +++ /dev/null @@ -1,263 +0,0 @@ -package rpc - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "io" - "sync" - "time" - - log "github.com/sirupsen/logrus" -) - -type JsonRPCRequestHandler func(req *JsonRPCRequest) (*JsonRPCResponse, error) -type JsonRPCResponseHandler func(resp *JsonRPCResponse) error - -type RequestHandler func(args interface{}) (interface{}, error) - -type GuestRPCManager struct { - reqHandler map[string]JsonRPCRequestHandler - reqRawHandler JsonRPCRequestHandler - respRawHandler JsonRPCResponseHandler - input io.Reader - output io.Writer -} - -type ResquestCallback func(resp *JsonRPCResponse) error -type reqCallbackStruct struct { - cb ResquestCallback - timeStamp time.Time - autoClear bool -} - -var ( - pendingRequests = map[json.Number]reqCallbackStruct{} - pendingRequestsMu = sync.RWMutex{} - - globalIDCounter uint64 = 1 - - ResponseTimeout = time.Minute * 10 // 10 minutes timeout for requests -) - -func RPCManagerSendRequest[T any](rpcMgr *GuestRPCManager, 
method string, params interface{}) (*T, error) { - resp, err := rpcMgr.SendRequest(method, params) - if err != nil { - return nil, err - } - // first marshal to json - jsonData, err := json.Marshal(resp) - if err != nil { - return nil, err - } - // then unmarshal to T - var resp2 T - err = json.Unmarshal(jsonData, &resp2) - if err != nil { - return nil, err - } - return &resp2, nil -} - -func NewGuestRPCManager(reqHandler JsonRPCRequestHandler, respHandler JsonRPCResponseHandler) *GuestRPCManager { - return &GuestRPCManager{ - reqHandler: make(map[string]JsonRPCRequestHandler), - reqRawHandler: reqHandler, - respRawHandler: respHandler, - } -} - -func (g *GuestRPCManager) SetInput(i io.Reader) { - g.input = i -} - -func (g *GuestRPCManager) SetOutput(o io.Writer) { - g.output = o -} - -func (g *GuestRPCManager) SetRequestCallback(id json.Number, callback ResquestCallback, autoClear bool) { - pendingRequestsMu.Lock() - defer pendingRequestsMu.Unlock() - pendingRequests[id] = reqCallbackStruct{ - cb: callback, - timeStamp: time.Now(), - autoClear: autoClear, - } -} - -func (g *GuestRPCManager) ClearRequestCallback(id json.Number) { - pendingRequestsMu.Lock() - defer pendingRequestsMu.Unlock() - delete(pendingRequests, id) -} - -func (g *GuestRPCManager) RegisterIncomingHandler(method string, handler RequestHandler) error { - if _, ok := g.reqHandler[method]; ok { - return fmt.Errorf("handler already registered for method %s", method) - } - g.reqHandler[method] = func(req *JsonRPCRequest) (*JsonRPCResponse, error) { - params := req.Params - result, err := handler(params) - if err != nil { - return NewJsonRPCErrorResponse(*req.ID, -1, err.Error(), nil), nil - } - return NewJsonRPCResponse(*req.ID, result), nil - } - return nil -} - -// high level function to send a request -func (g *GuestRPCManager) SendRequest(method string, params interface{}) (interface{}, error) { - req := NewJsonRPCRequest(method, params) - resp, err := g.SendJsonRequest(req) - if err != nil { - return nil, err - } - if resp.Error != nil { - return nil, fmt.Errorf("error: %v", resp.Error) - } - return resp.Result, nil -} - -// low level function to send a json request -func (g *GuestRPCManager) SendJsonRequest(req *JsonRPCRequest) (*JsonRPCResponse, error) { - if g.output == nil { - return nil, fmt.Errorf("output file not set") - } - newID := json.Number(fmt.Sprintf("%d", globalIDCounter)) - req.ID = &newID - globalIDCounter++ - - // set callback to unblock - ch := make(chan *JsonRPCResponse, 1) - g.SetRequestCallback(*req.ID, func(resp *JsonRPCResponse) error { - log.Debugf("Received response for request %s", *req.ID) - ch <- resp - return nil - }, true) - - if err := req.Send(g.output); err != nil { - return nil, err - } - - // wait for response - select { - case <-time.After(ResponseTimeout): - return nil, fmt.Errorf("timeout waiting for response") - case resp := <-ch: - return resp, nil - } -} - -func (g *GuestRPCManager) sendErrorJsonResponse(id json.Number, err error) error { - resp := NewJsonRPCErrorResponse(id, -1, err.Error(), nil) - if g.output == nil { - return fmt.Errorf("output file not set") - } - return resp.Send(g.output) -} - -func (g *GuestRPCManager) Run() { - // read from stdin - reader := g.input - - for { - // read a 64 bit uint - buf := make([]byte, 8) - if _, err := reader.Read(buf); err != nil { - log.Errorf("Error reading from stdin: %v", err) - continue - } - dataLen := binary.LittleEndian.Uint64(buf) - - if dataLen == 0 { - log.Infof("Exiting") - break - } - - log.Debugf("Got message size: %d", 
dataLen) - // read dataLen bytes - data := make([]byte, dataLen) - if _, err := io.ReadFull(reader, data); err != nil { - log.Errorf("Error reading from stdin: %v", err) - continue - } - - if len(data) == 0 { - log.Infof("Exiting") - break - } - - var req JsonRPCRequest - err := req.Unmarshal([]byte(data)) - if err == nil && g.reqHandler != nil { - if req.Method == nil { - log.Errorf("Invalid request: %v", req) - continue - } - - // check raw handler - if g.reqRawHandler != nil { - resp, err := g.reqRawHandler(&req) - if err != nil { - log.Errorf("Error handling request: %v", err) - } - if resp != nil { - if err = resp.Send(g.output); err != nil { - log.Errorf("Error sending response: %v", err) - } - continue - } - } - - // request is valid - if hdl, ok := g.reqHandler[*req.Method]; ok { - go func() { - if resp, err := hdl(&req); err != nil { - log.Errorf("Error handling request: %v", err) - if err = g.sendErrorJsonResponse(*req.ID, err); err != nil { - log.Errorf("Error sending error response: %v", err) - } - } else { - log.Debugf("Sending response for method %s", *req.Method) - if err = resp.Send(g.output); err != nil { - log.Errorf("Error sending response: %v", err) - } - } - }() - } else { - log.Infof("No handler for method %s", *req.Method) - if err = g.sendErrorJsonResponse(*req.ID, fmt.Errorf("method not found")); err != nil { - log.Errorf("Error sending error response: %v", err) - } - } - continue - } - - var resp JsonRPCResponse - err = resp.Unmarshal([]byte(data)) - if err == nil { - if g.respRawHandler != nil { - // response is valid - if err = g.respRawHandler(&resp); err != nil { - log.Errorf("Error handling response: %v", err) - } - } - - // find the id in the request - pendingRequestsMu.RLock() - callback, ok := pendingRequests[*resp.ID] - pendingRequestsMu.RUnlock() - if ok { - go func() { - if err = callback.cb(&resp); err != nil { - log.Errorf("Error handling response: %v", err) - } - if callback.autoClear { - g.ClearRequestCallback(*resp.ID) - } - }() - } - } - } -} diff --git a/pkg/rpc/json.go b/pkg/rpc/json.go deleted file mode 100644 index 87048da..0000000 --- a/pkg/rpc/json.go +++ /dev/null @@ -1,188 +0,0 @@ -package rpc - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "io" -) - -type JsonRPCRequest struct { - Version string `json:"jsonrpc"` - Method *string `json:"method"` - Params interface{} `json:"params,omitempty"` - ID *json.Number `json:"id"` -} - -type JsonRPCNotification struct { - Version string `json:"jsonrpc"` - Method *string `json:"method"` - Params interface{} `json:"params"` -} - -type JsonRPCError struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data"` -} - -type JsonRPCResponse struct { - Version string `json:"jsonrpc"` - Result interface{} `json:"result,omitempty"` - Error *JsonRPCError `json:"error,omitempty"` - ID *json.Number `json:"id"` -} - -// create request -func NewJsonRPCRequest(method string, params interface{}) *JsonRPCRequest { - res := &JsonRPCRequest{ - Version: "2.0", - Method: &method, - Params: params, - } - res.ID = nil - return res -} - -func (r *JsonRPCRequest) Send(out io.Writer) error { - if r.ID == nil { - return fmt.Errorf("invalid request id") - } - b, err := r.Marshal() - if err != nil { - return err - } - - // send little endian length of b using uint64 - length := uint64(len(b)) - buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, length) - _, err = out.Write(buf) - if err != nil { - return err - } - - // write b to output pipe - n, err := out.Write(b) - if 
err != nil { - return err - } - if n != len(b) { - return fmt.Errorf("error writing to output pipe") - } - return nil -} - -// Marshal operations - -func (r *JsonRPCRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *JsonRPCRequest) CreateSuccessResponse(result interface{}) *JsonRPCResponse { - return &JsonRPCResponse{ - Version: r.Version, - Result: result, - ID: r.ID, - } -} - -func (r *JsonRPCRequest) CreateErrorResponse(code int, message string, data interface{}) *JsonRPCResponse { - return &JsonRPCResponse{ - Version: r.Version, - Error: &JsonRPCError{ - Code: code, - Message: message, - Data: data, - }, - ID: r.ID, - } -} - -func (r *JsonRPCNotification) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *JsonRPCResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -// Unmarshal operations - -func (r *JsonRPCRequest) Unmarshal(data []byte) error { - err := json.Unmarshal(data, r) - if err != nil { - return err - } - if r.Method == nil || r.ID == nil { - return fmt.Errorf("invalid request") - } - return nil -} - -func (r *JsonRPCNotification) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -func NewJsonRPCResponse(id json.Number, result interface{}) *JsonRPCResponse { - return &JsonRPCResponse{ - Version: "2.0", - Result: result, - ID: &id, - } -} - -func NewJsonRPCErrorResponse(id json.Number, code int, message string, data interface{}) *JsonRPCResponse { - return &JsonRPCResponse{ - Version: "2.0", - Error: &JsonRPCError{ - Code: code, - Message: message, - Data: data, - }, - ID: &id, - } -} - -func (r *JsonRPCResponse) Unmarshal(data []byte) error { - err := json.Unmarshal(data, r) - if err != nil { - return err - } - if r.ID == nil { - return fmt.Errorf("invalid response") - } - if r.Error != nil && r.Result != nil { - return fmt.Errorf("invalid response") - } - return nil -} - -func (r *JsonRPCResponse) Send(out io.Writer) error { - if r.ID == nil { - return fmt.Errorf("invalid response id") - } - b, err := r.Marshal() - if err != nil { - return err - } - - // send little endian length of b using uint64 - length := uint64(len(b)) - buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, length) - _, err = out.Write(buf) - if err != nil { - return err - } - - // write b to output pipe - n, err := out.Write(b) - if err != nil { - return err - } - if n != len(b) { - return fmt.Errorf("error writing to output pipe") - } - return nil -} diff --git a/pkg/rpc/payload/chat.go b/pkg/rpc/payload/chat.go deleted file mode 100644 index 7343d2a..0000000 --- a/pkg/rpc/payload/chat.go +++ /dev/null @@ -1,79 +0,0 @@ -package payload - -import ( - "encoding/json" -) - -type ChatCompletionRequest struct { - Messages []ChatMessageV2 `json:"messages"` - Model string `json:"model"` - ToolsetId string `json:"toolset_id"` -} - -// marshal operations -func (r *ChatCompletionRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -// unmarshal to ChatCompletionRequest -func (r *ChatCompletionRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type ChatMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type ChatCompletionResponse struct { - Id string `json:"id"` - Model string `json:"model"` - Choices []ChatChoice `json:"choices"` -} - -type ChatChoice struct { - Message ChatMessage `json:"message"` - Index json.Number `json:"index"` - Reason string `json:"finish_reason"` -} - -func (r *ChatCompletionResponse) Marshal() ([]byte, error) { - return 
json.Marshal(r) -} - -func (r *ChatCompletionResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type ChatMessageV2 struct { - Metadata map[string]interface{} `json:"metadata"` - Content string `json:"content"` -} - -type ChatCompletionResponseV2 struct { - Id string `json:"id"` - Model string `json:"model"` - Messages []ChatMessageV2 `json:"messages"` -} - -func (r *ChatCompletionResponseV2) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *ChatCompletionResponseV2) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type ChatCompletionRequestV2 struct { - Messages []ChatMessageV2 `json:"messages"` - Model string `json:"model"` - ToolsetId string `json:"toolset_id"` -} - -func (r *ChatCompletionRequestV2) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *ChatCompletionRequestV2) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} diff --git a/pkg/rpc/payload/const.go b/pkg/rpc/payload/const.go deleted file mode 100644 index 8199499..0000000 --- a/pkg/rpc/payload/const.go +++ /dev/null @@ -1,24 +0,0 @@ -package payload - -const ( - HostCallVectorStoreCreate = "vectorstore.create" - HostCallVectorStoreInsert = "vectorstore.insert" - HostCallVectorStoreSearch = "vectorstore.search" - HostCallVectorStoreDelete = "vectorstore.delete" - - HostCallMessagePassingRegister = "messagepassing.register" - HostCallMessagePassingUnregister = "messagepassing.unregister" - HostCallMessagePassingLookup = "messagepassing.lookup" - HostCallMessagePassingSend = "messagepassing.send" - - HostCallTransform = "transform" - HostCallTransformConfig = "transform.config" - - HostCallToolNew = "tool.new" - HostCallToolsetNew = "toolset.new" - HostCallToolsetInstallBuiltins = "toolset.install.builtins" - - HostCallInput = "input" - HostCallSpeak = "speak" - HostCallRecord = "record" -) diff --git a/pkg/rpc/payload/msgpassing.go b/pkg/rpc/payload/msgpassing.go deleted file mode 100644 index 44638de..0000000 --- a/pkg/rpc/payload/msgpassing.go +++ /dev/null @@ -1,101 +0,0 @@ -package payload - -import "encoding/json" - -type MessagePassingRegisterRequest struct { - Name string `json:"name"` - Method string `json:"method"` -} - -func (r *MessagePassingRegisterRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *MessagePassingRegisterRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type MessagePassingRegisterResponse struct { - MsgPassingId uint64 `json:"msg_passing_id"` -} - -func (r *MessagePassingRegisterResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *MessagePassingRegisterResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type MessagePassingUnregisterRequest struct { - MsgPassingId uint64 `json:"msg_passing_id"` -} - -func (r *MessagePassingUnregisterRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *MessagePassingUnregisterRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type MessagePassingUnregisterResponse struct { - MsgPassingId uint64 `json:"msg_passing_id"` -} - -func (r *MessagePassingUnregisterResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *MessagePassingUnregisterResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type MessagePassingLookupRequest struct { - Name string `json:"name"` -} - -func (r *MessagePassingLookupRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r 
*MessagePassingLookupRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type MessagePassingLookupResponse struct { - MsgPassingId uint64 `json:"msg_passing_id"` -} - -func (r *MessagePassingLookupResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *MessagePassingLookupResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type MessagePassingSendRequest struct { - MsgPassingId uint64 `json:"msg_passing_id"` - Data []byte `json:"data"` -} - -func (r *MessagePassingSendRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *MessagePassingSendRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type MessagePassingSendResponse struct { - MsgPassingId uint64 `json:"msg_passing_id"` -} - -func (r *MessagePassingSendResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *MessagePassingSendResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} diff --git a/pkg/rpc/payload/tools.go b/pkg/rpc/payload/tools.go deleted file mode 100644 index a5ebee4..0000000 --- a/pkg/rpc/payload/tools.go +++ /dev/null @@ -1,87 +0,0 @@ -package payload - -import "encoding/json" - -type NewToolParam struct { - Name string `json:"name"` - Type string `json:"type"` - Description string `json:"description"` - Required bool `json:"required"` -} - -type NewToolRequest struct { - Name string `json:"name"` - Description string `json:"description"` - Params []NewToolParam `json:"params"` - Cb string `json:"cb"` -} - -type NewToolResponse struct { - Tid string `json:"tool_id"` -} - -type NewToolsetRequest struct { - Name string `json:"name"` - Description string `json:"description"` - ToolIds []string `json:"tool_ids"` -} - -type NewToolsetResponse struct { - Tsid string `json:"toolset_id"` -} - -type ToolsetInstallBuiltinsRequest struct { - Tsid string `json:"toolset_id"` -} - -type ToolsetInstallBuiltinsResponse struct { - Tsid string `json:"toolset_id"` -} - -func (r *NewToolRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *NewToolResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *NewToolsetRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *NewToolsetResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *NewToolRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -func (r *NewToolResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -func (r *NewToolsetRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -func (r *NewToolsetResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -func (r *ToolsetInstallBuiltinsRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *ToolsetInstallBuiltinsResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *ToolsetInstallBuiltinsRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -func (r *ToolsetInstallBuiltinsResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} diff --git a/pkg/rpc/payload/transform.go b/pkg/rpc/payload/transform.go deleted file mode 100644 index ca43a6c..0000000 --- a/pkg/rpc/payload/transform.go +++ /dev/null @@ -1,89 +0,0 @@ -package payload - -import "encoding/json" - -type TransformType int -type TransformOperation int - -const ( - TransformTypeImage TransformType = iota - TransformTypeText - TransformTypeAudio - 
TransformTypeVideo - TransformTypeTensor - TransformTypeVector - TransformTypeUnknown -) - -const ( - TransformOperationLLM TransformOperation = iota - TransformOperationTools - TransformOperationEmbeddings - TransformOperationOCR - TransformOperationTextToSpeech - TransformOperationSpeechToText - TransformOperationTextToImage -) - -// Transform request -type TransformRequest struct { - InputTypes []TransformType `json:"input_types"` - OutputTypes []TransformType `json:"output_types"` - Operations []TransformOperation `json:"operations"` - Params interface{} `json:"params"` -} - -type TransformResult struct { - Type TransformType `json:"type"` - Data interface{} `json:"data"` -} - -// Transform response -type TransformResponse struct { - Results []TransformResult `json:"results"` -} - -// marshal operations -func (r *TransformRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *TransformResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -// unmarshal to TransformRequest -func (r *TransformRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -func (r *TransformResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type TransformConfigRequest struct { - Test string `json:"test"` - Reset bool `json:"reset"` -} - -type TransformConfigResponse struct { - Result interface{} `json:"result"` -} - -// marshal operations -func (r *TransformConfigRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *TransformConfigResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -// unmarshal to TransformRequest -func (r *TransformConfigRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -func (r *TransformConfigResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} diff --git a/pkg/rpc/payload/transform/const.go b/pkg/rpc/payload/transform/const.go deleted file mode 100644 index c2c2ab8..0000000 --- a/pkg/rpc/payload/transform/const.go +++ /dev/null @@ -1,5 +0,0 @@ -package transform - -const ( - HostCallTransform = "transform" -) diff --git a/pkg/rpc/payload/transform/embeddings.go b/pkg/rpc/payload/transform/embeddings.go deleted file mode 100644 index 094d6b0..0000000 --- a/pkg/rpc/payload/transform/embeddings.go +++ /dev/null @@ -1,55 +0,0 @@ -package transform - -import ( - "encoding/json" - "fmt" -) - -type EmbeddingsRequest struct { - Input string `json:"input"` - Model string `json:"model"` -} - -func (r *EmbeddingsRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *EmbeddingsRequest) Unmarshal(data []byte) error { - err := json.Unmarshal(data, r) - if err != nil { - return err - } - if r.Input == "" || r.Model == "" { - return fmt.Errorf("invalid input or model") - } - return nil -} - -type EmbeddingsResponse struct { - Object string `json:"object"` - Data []EmbeddingObject `json:"data"` - Model string `json:"model"` - Usage interface{} `json:"usage"` -} - -func (r *EmbeddingsResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *EmbeddingsResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type EmbeddingObject struct { - Object string `json:"object"` - Embedding []float64 `json:"embedding"` - Index int `json:"index"` -} - -func (r *EmbeddingObject) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *EmbeddingObject) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} diff --git a/pkg/rpc/payload/transform/image_gen.go 
b/pkg/rpc/payload/transform/image_gen.go deleted file mode 100644 index 2fe89b5..0000000 --- a/pkg/rpc/payload/transform/image_gen.go +++ /dev/null @@ -1,36 +0,0 @@ -package transform - -import "encoding/json" - -type ImageGenerationRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - ResponseFormat string `json:"response_format"` -} - -func (r *ImageGenerationRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *ImageGenerationRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type ImageObject struct { - Url string `json:"url"` - B64Json string `json:"b64_json"` - RevisedPrompt string `json:"revised_prompt"` -} - -type ImageGenerationResponse struct { - Created json.Number `json:"created"` - Data []ImageObject `json:"data"` -} - -func (r *ImageGenerationResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *ImageGenerationResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} diff --git a/pkg/rpc/payload/transform/speech2text.go b/pkg/rpc/payload/transform/speech2text.go deleted file mode 100644 index f3cc785..0000000 --- a/pkg/rpc/payload/transform/speech2text.go +++ /dev/null @@ -1,28 +0,0 @@ -package transform - -import "encoding/json" - -type SpeechToTextRequest struct { - Model string `json:"model"` - Audio string `json:"audio"` -} - -func (r *SpeechToTextRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *SpeechToTextRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type SpeechToTextResponse struct { - Text string `json:"text"` -} - -func (r *SpeechToTextResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *SpeechToTextResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} diff --git a/pkg/rpc/payload/transform/text2speech.go b/pkg/rpc/payload/transform/text2speech.go deleted file mode 100644 index c121ae6..0000000 --- a/pkg/rpc/payload/transform/text2speech.go +++ /dev/null @@ -1,30 +0,0 @@ -package transform - -import "encoding/json" - -type TextToSpeechRequest struct { - Model string `json:"model"` - Input string `json:"input"` - Voice string `json:"voice"` - Format string `json:"response_format"` -} - -func (r *TextToSpeechRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *TextToSpeechRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type TextToSpeechResponse struct { - EncodedAudio string `json:"audio"` -} - -func (r *TextToSpeechResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *TextToSpeechResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} diff --git a/pkg/rpc/payload/vectorstore.go b/pkg/rpc/payload/vectorstore.go deleted file mode 100644 index 74a8b82..0000000 --- a/pkg/rpc/payload/vectorstore.go +++ /dev/null @@ -1,102 +0,0 @@ -package payload - -import "encoding/json" - -type VectorStoreCreateRequest struct { - Name string `json:"name"` - Dimentions uint64 `json:"dimentions"` -} - -func (r *VectorStoreCreateRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *VectorStoreCreateRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type VectorStoreCreateResponse struct { - VID int `json:"vid"` -} - -func (r *VectorStoreCreateResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *VectorStoreCreateResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type 
VectorStoreDeleteRequest struct { - VID int `json:"vid"` -} - -func (r *VectorStoreDeleteRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *VectorStoreDeleteRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type VectorStoreDeleteResponse struct { - VID int `json:"vid"` -} - -func (r *VectorStoreDeleteResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *VectorStoreDeleteResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type VectorStoreInsertRequest struct { - VID int `json:"vid"` - Vector []float32 `json:"vector"` - Data []byte `json:"data"` -} - -func (r *VectorStoreInsertRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *VectorStoreInsertRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type VectorStoreInsertResponse struct { - VID int `json:"vid"` -} - -func (r *VectorStoreInsertResponse) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *VectorStoreInsertResponse) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type VectorStoreSearchRequest struct { - VID int `json:"vid"` - Vector []float32 `json:"vector"` - Limit uint64 `json:"limit"` -} - -func (r *VectorStoreSearchRequest) Marshal() ([]byte, error) { - return json.Marshal(r) -} - -func (r *VectorStoreSearchRequest) Unmarshal(data []byte) error { - return json.Unmarshal(data, r) -} - -type VectorStoreSearchResponseEntry struct { - Vector []float32 `json:"vector"` - Data []byte `json:"data"` -} - -type VectorStoreSearchResponse struct { - VID int `json:"vid"` - Entries []VectorStoreSearchResponseEntry `json:"entries"` -} diff --git a/pkg/rpc/transform.go b/pkg/rpc/transform.go deleted file mode 100644 index 0872446..0000000 --- a/pkg/rpc/transform.go +++ /dev/null @@ -1,150 +0,0 @@ -package rpc - -import ( - "fmt" - - "github.com/lfedgeai/spear/pkg/rpc/payload" - "github.com/lfedgeai/spear/pkg/rpc/payload/transform" - "github.com/lfedgeai/spear/pkg/utils" -) - -func ChatCompletion(rpcMgr *GuestRPCManager, model string, msgs []payload.ChatMessageV2, tsId string) ([]payload.ChatMessageV2, error) { - req := &payload.ChatCompletionRequestV2{ - Model: model, - Messages: msgs, - ToolsetId: tsId, - } - resp, err := RPCManagerSendRequest[payload.TransformResponse](rpcMgr, transform.HostCallTransform, payload.TransformRequest{ - InputTypes: []payload.TransformType{payload.TransformTypeText}, - OutputTypes: []payload.TransformType{payload.TransformTypeText}, - Operations: []payload.TransformOperation{ - payload.TransformOperationLLM, - payload.TransformOperationTools, - }, - Params: req, - }) - if err != nil { - return nil, fmt.Errorf("error geting result %v", err) - } - - if len(resp.Results) != 1 { - return nil, fmt.Errorf("unexpected number of results: %d", len(resp.Results)) - } - - if resp.Results[0].Type != payload.TransformTypeText { - return nil, fmt.Errorf("unexpected result type: %v", resp.Results[0].Type) - } - - var chatResp payload.ChatCompletionResponseV2 - if err := utils.InterfaceToType(&chatResp, resp.Results[0].Data); err != nil { - return nil, fmt.Errorf("error converting response: %v", err) - } - - return chatResp.Messages, nil -} - -func TextToSpeech(rpcMgr *GuestRPCManager, model, voice, input, format string) (*transform.TextToSpeechResponse, error) { - req := &transform.TextToSpeechRequest{ - Model: model, - Voice: voice, - Input: input, - Format: format, - } - resp, err := 
RPCManagerSendRequest[payload.TransformResponse](rpcMgr, transform.HostCallTransform, payload.TransformRequest{ - InputTypes: []payload.TransformType{payload.TransformTypeText}, - OutputTypes: []payload.TransformType{payload.TransformTypeAudio}, - Operations: []payload.TransformOperation{ - payload.TransformOperationTextToSpeech, - }, - Params: req, - }) - if err != nil { - return nil, fmt.Errorf("error getting result: %v", err) - } - - if len(resp.Results) != 1 { - return nil, fmt.Errorf("unexpected number of results: %d", len(resp.Results)) - } - - if resp.Results[0].Type != payload.TransformTypeAudio { - return nil, fmt.Errorf("unexpected result type: %v", resp.Results[0].Type) - } - - var ttsResp transform.TextToSpeechResponse - if err := utils.InterfaceToType(&ttsResp, resp.Results[0].Data); err != nil { - return nil, fmt.Errorf("error converting response: %v", err) - } - - return &ttsResp, nil -} - -func Embeddings(rpcMgr *GuestRPCManager, model, input string) ([]float64, error) { - req := &transform.EmbeddingsRequest{ - Model: model, - Input: input, - } - resp, err := RPCManagerSendRequest[payload.TransformResponse](rpcMgr, transform.HostCallTransform, payload.TransformRequest{ - InputTypes: []payload.TransformType{payload.TransformTypeText}, - OutputTypes: []payload.TransformType{payload.TransformTypeVector}, - Operations: []payload.TransformOperation{ - payload.TransformOperationEmbeddings, - }, - Params: req, - }) - if err != nil { - return nil, fmt.Errorf("error getting result: %v", err) - } - - if len(resp.Results) != 1 { - return nil, fmt.Errorf("unexpected number of results: %d", len(resp.Results)) - } - - if resp.Results[0].Type != payload.TransformTypeVector { - return nil, fmt.Errorf("unexpected result type: %v", resp.Results[0].Type) - } - - var embResp transform.EmbeddingsResponse - if err := utils.InterfaceToType(&embResp, resp.Results[0].Data); err != nil { - return nil, fmt.Errorf("error converting response: %v", err) - } - - if len(embResp.Data) != 1 { - return nil, fmt.Errorf("unexpected number of embeddings: %d", len(embResp.Data)) - } - - return embResp.Data[0].Embedding, nil -} - -func TextToImage(rpcMgr *GuestRPCManager, model, prompt, format string) (*transform.ImageGenerationResponse, error) { - req := &transform.ImageGenerationRequest{ - Model: model, - Prompt: prompt, - ResponseFormat: format, - } - resp, err := RPCManagerSendRequest[payload.TransformResponse](rpcMgr, transform.HostCallTransform, payload.TransformRequest{ - InputTypes: []payload.TransformType{payload.TransformTypeText}, - OutputTypes: []payload.TransformType{payload.TransformTypeImage}, - Operations: []payload.TransformOperation{ - payload.TransformOperationTextToImage, - }, - Params: req, - }) - if err != nil { - return nil, fmt.Errorf("error getting result: %v", err) - } - - if len(resp.Results) != 1 { - return nil, fmt.Errorf("unexpected number of results: %d", len(resp.Results)) - } - - if resp.Results[0].Type != payload.TransformTypeImage { - return nil, fmt.Errorf("unexpected result type: %v", resp.Results[0].Type) - } - - var ttiResp transform.ImageGenerationResponse - if err := utils.InterfaceToType(&ttiResp, resp.Results[0].Data); err != nil { - return nil, fmt.Errorf("error converting response: %v", err) - } - - return &ttiResp, nil -} diff --git a/pkg/tools/docker/docker.go b/pkg/tools/docker/docker.go index e204198..33958df 100644 --- a/pkg/tools/docker/docker.go +++ b/pkg/tools/docker/docker.go @@ -5,15 +5,15 @@ import ( "time" "github.com/lfedgeai/spear/pkg/common" - 
"github.com/lfedgeai/spear/worker" - "github.com/lfedgeai/spear/worker/task/docker" + "github.com/lfedgeai/spear/spearlet" + "github.com/lfedgeai/spear/spearlet/task/docker" log "github.com/sirupsen/logrus" "github.com/docker/docker/api/types/container" ) type TestSetup struct { - w *worker.Worker + w *spearlet.Spearlet vecStore *container.CreateResponse } @@ -44,8 +44,8 @@ func NewTestSetup() *TestSetup { t.startVectorStoreContainer() // setup the test environment - cfg := worker.NewServeWorkerConfig("localhost", "8080", []string{}, true, common.SpearPlatformAddress) - t.w = worker.NewWorker(cfg) + cfg := spearlet.NewServeSpearletConfig("localhost", "8080", []string{}, true, common.SpearPlatformAddress) + t.w = spearlet.NewSpearlet(cfg) t.w.Initialize() go t.w.StartServer() time.Sleep(5 * time.Second) diff --git a/pkg/utils/protohelper/helper.go b/pkg/utils/protohelper/helper.go new file mode 100644 index 0000000..12c261f --- /dev/null +++ b/pkg/utils/protohelper/helper.go @@ -0,0 +1,141 @@ +package protohelper + +import ( + "fmt" + + flatbuffers "github.com/google/flatbuffers/go" + "github.com/lfedgeai/spear/pkg/spear/proto/transform" + "github.com/lfedgeai/spear/pkg/spear/proto/transport" +) + +type IfWithInit[T any] interface { + *T + Init([]byte, flatbuffers.UOffsetT) +} + +func UnwrapTransformRequest[T any, P IfWithInit[T]](d P, s *transform.TransformRequest) error { + if d == nil { + return fmt.Errorf("destination is nil") + } + if s == nil { + return fmt.Errorf("source is nil") + } + tbl := flatbuffers.Table{} + if !s.Params(&tbl) { + return fmt.Errorf("error getting params") + } + d.Init(tbl.Bytes, tbl.Pos) + return nil +} + +func CreateErrorTransportResponse(id int64, code int, + msg string) *transport.TransportResponse { + builder := flatbuffers.NewBuilder(0) + msgOff := builder.CreateString(msg) + + transport.TransportResponseStart(builder) + transport.TransportResponseAddId(builder, id) + transport.TransportResponseAddCode(builder, int32(code)) + transport.TransportResponseAddMessage(builder, msgOff) + respOff := transport.TransportResponseEnd(builder) + builder.Finish(respOff) + + resp := transport.GetRootAsTransportResponse(builder.FinishedBytes(), 0) + return resp +} + +func TransportResponseToRaw(resp *transport.TransportResponse) ([]byte, error) { + if resp == nil { + return nil, fmt.Errorf("error in TransportResponseToRaw") + } + builder := flatbuffers.NewBuilder(0) + respOff := builder.CreateByteVector(resp.ResponseBytes()) + msgOff := builder.CreateString(string(resp.Message())) + + transport.TransportResponseStart(builder) + transport.TransportResponseAddId(builder, resp.Id()) + transport.TransportResponseAddCode(builder, resp.Code()) + transport.TransportResponseAddMessage(builder, msgOff) + transport.TransportResponseAddResponse(builder, respOff) + respOff = transport.TransportResponseEnd(builder) + + transport.TransportMessageRawStart(builder) + transport.TransportMessageRawAddDataType(builder, + transport.TransportMessageRaw_DataTransportResponse) + transport.TransportMessageRawAddData(builder, respOff) + raw := transport.TransportMessageRawEnd(builder) + + builder.Finish(raw) + + data := builder.FinishedBytes() + return data, nil +} + +func RPCSignalToRaw(method transport.Signal, data []byte) ([]byte, error) { + builder := flatbuffers.NewBuilder(0) + dataOff := builder.CreateByteVector(data) + + transport.TransportSignalStart(builder) + transport.TransportSignalAddMethod(builder, method) + transport.TransportSignalAddPayload(builder, dataOff) + signalOff := 
transport.TransportSignalEnd(builder) + + transport.TransportMessageRawStart(builder) + transport.TransportMessageRawAddDataType(builder, + transport.TransportMessageRaw_DataTransportSignal) + transport.TransportMessageRawAddData(builder, signalOff) + raw := transport.TransportMessageRawEnd(builder) + + builder.Finish(raw) + + res := builder.FinishedBytes() + return res, nil +} + +func RPCBufferResquestToRaw(id int64, method transport.Method, + req_buffer []byte) ([]byte, error) { + if len(req_buffer) == 0 { + return nil, fmt.Errorf("error in RPCBufferResponseToRaw") + } + builder := flatbuffers.NewBuilder(len(req_buffer) + 512) + reqBytesOff := builder.CreateByteVector(req_buffer) + + transport.TransportRequestStart(builder) + transport.TransportRequestAddId(builder, id) + transport.TransportRequestAddMethod(builder, method) + transport.TransportRequestAddRequest(builder, reqBytesOff) + reqOff := transport.TransportRequestEnd(builder) + + transport.TransportMessageRawStart(builder) + transport.TransportMessageRawAddDataType(builder, + transport.TransportMessageRaw_DataTransportRequest) + transport.TransportMessageRawAddData(builder, reqOff) + raw := transport.TransportMessageRawEnd(builder) + builder.Finish(raw) + + data := builder.FinishedBytes() + return data, nil +} + +func RPCBufferResponseToRaw(id int64, resp_buffer []byte) ([]byte, error) { + if len(resp_buffer) == 0 { + return nil, fmt.Errorf("error in RPCBufferResponseToRaw") + } + builder := flatbuffers.NewBuilder(len(resp_buffer) + 512) + respBytesOff := builder.CreateByteVector(resp_buffer) + + transport.TransportResponseStart(builder) + transport.TransportResponseAddId(builder, id) + transport.TransportResponseAddResponse(builder, respBytesOff) + respOff := transport.TransportResponseEnd(builder) + + transport.TransportMessageRawStart(builder) + transport.TransportMessageRawAddDataType(builder, + transport.TransportMessageRaw_DataTransportResponse) + transport.TransportMessageRawAddData(builder, respOff) + raw := transport.TransportMessageRawEnd(builder) + builder.Finish(raw) + + data := builder.FinishedBytes() + return data, nil +} diff --git a/proto/chat/common.fbs b/proto/chat/common.fbs new file mode 100644 index 0000000..b9c339e --- /dev/null +++ b/proto/chat/common.fbs @@ -0,0 +1,26 @@ +namespace spear.proto.chat; + +enum Role:byte { + System = 0, + User = 1, + Assistant = 2, + Developer = 3, + Other = 4, +} + +enum Reason:byte { + ToolCalls = 0, + Length = 1, + Stop = 2, + Other = 3, +} + +table ChatMetadata { + role: Role; + reason: Reason; +} + +table ChatMessage { + metadata: ChatMetadata (required); + content: string (required); +} diff --git a/proto/chat/request.fbs b/proto/chat/request.fbs new file mode 100644 index 0000000..bab2ebc --- /dev/null +++ b/proto/chat/request.fbs @@ -0,0 +1,17 @@ +include "common.fbs"; +include "tool/common.fbs"; + +namespace spear.proto.chat; + +table ToolInfo { + data: spear.proto.tool.ToolInfo; +} + +table ChatCompletionRequest { + messages: [ChatMessage] (required); + model: string (required); + tools: [ToolInfo]; + return_on_toolcall: bool = false; +} + +root_type ChatCompletionRequest; diff --git a/proto/chat/response.fbs b/proto/chat/response.fbs new file mode 100644 index 0000000..69ff8cd --- /dev/null +++ b/proto/chat/response.fbs @@ -0,0 +1,11 @@ +include "common.fbs"; + +namespace spear.proto.chat; + +table ChatCompletionResponse { + code: int = 0; + error: string; + messages: [ChatMessage] (required); +} + +root_type ChatCompletionResponse; diff --git 
a/proto/custom/request.fbs b/proto/custom/request.fbs new file mode 100644 index 0000000..a46dbcf --- /dev/null +++ b/proto/custom/request.fbs @@ -0,0 +1,8 @@ +namespace spear.proto.custom; + +table CustomRequest { + method_str: string (required); + params_str: string (required); +} + +root_type CustomRequest; \ No newline at end of file diff --git a/proto/custom/response.fbs b/proto/custom/response.fbs new file mode 100644 index 0000000..02c5734 --- /dev/null +++ b/proto/custom/response.fbs @@ -0,0 +1,7 @@ +namespace spear.proto.custom; + +table CustomResponse { + data: [ubyte] (required); +} + +root_type CustomResponse; \ No newline at end of file diff --git a/proto/io/input-req.fbs b/proto/io/input-req.fbs new file mode 100644 index 0000000..1e22dc1 --- /dev/null +++ b/proto/io/input-req.fbs @@ -0,0 +1,9 @@ + +namespace spear.proto.io; + +table InputRequest { + prompt: string (required); + dryrun: bool=false; +} + +root_type InputRequest; diff --git a/proto/io/input-resp.fbs b/proto/io/input-resp.fbs new file mode 100644 index 0000000..88c677f --- /dev/null +++ b/proto/io/input-resp.fbs @@ -0,0 +1,8 @@ + +namespace spear.proto.io; + +table InputResponse { + text: string (required); +} + +root_type InputResponse; diff --git a/proto/io/record-req.fbs b/proto/io/record-req.fbs new file mode 100644 index 0000000..3c79bc8 --- /dev/null +++ b/proto/io/record-req.fbs @@ -0,0 +1,8 @@ +namespace spear.proto.io; + +table RecordRequest { + prompt: string (required); + model: string; +} + +root_type RecordRequest; \ No newline at end of file diff --git a/proto/io/record-resp.fbs b/proto/io/record-resp.fbs new file mode 100644 index 0000000..bc0e761 --- /dev/null +++ b/proto/io/record-resp.fbs @@ -0,0 +1,7 @@ +namespace spear.proto.io; + +table RecordResponse { + text: string (required); +} + +root_type RecordResponse; \ No newline at end of file diff --git a/proto/io/speak-req.fbs b/proto/io/speak-req.fbs new file mode 100644 index 0000000..12238d9 --- /dev/null +++ b/proto/io/speak-req.fbs @@ -0,0 +1,11 @@ + +namespace spear.proto.io; + +table SpeakRequest { + text: string (required); + model: string; + voice: string; + format: string; +} + +root_type SpeakRequest; diff --git a/proto/io/speak-resp.fbs b/proto/io/speak-resp.fbs new file mode 100644 index 0000000..8db0cd7 --- /dev/null +++ b/proto/io/speak-resp.fbs @@ -0,0 +1,6 @@ + +namespace spear.proto.io; + +table SpeakResponse { + data: string (required); +} diff --git a/proto/tool/common.fbs b/proto/tool/common.fbs new file mode 100644 index 0000000..3576c91 --- /dev/null +++ b/proto/tool/common.fbs @@ -0,0 +1,47 @@ +namespace spear.proto.tool; + +enum BuiltinToolID: uint16 { + Invalid = 0, + Datetime = 1, + Sleep = 2, + SearchContactEmail = 3, + // email tools + ListOpenEmails = 4, + ComposeEmail = 5, + SendEmailDraftWindow = 6, + // mouse tools + MouseRightClick = 7, + MouseLeftClick = 8, + // phone tools + PhoneCall = 9, + // screen tools + FullScreenshot = 10, + // web tools + OpenURL = 11, + ScrollDown = 12, + ScrollUp = 13, + PageDown = 14, + PageUp = 15, + WebScreenshot = 16, + + Max = 17 +} + +table BuiltinToolInfo { + tool_id: uint16; +} + +table NormalToolInfo { + workload_id: [ubyte]; + tool_id: uint16; +} + +table InternalToolInfo { + tool_id: uint16; +} + +union ToolInfo { + BuiltinToolInfo, + NormalToolInfo, + InternalToolInfo, +} diff --git a/proto/tool/internal-req.fbs b/proto/tool/internal-req.fbs new file mode 100644 index 0000000..55383c5 --- /dev/null +++ b/proto/tool/internal-req.fbs @@ -0,0 +1,18 @@ +include "common.fbs"; 
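+// Registration request for a guest-defined (internal) tool sent by the SDK: carries the tool name, description, and typed parameter specs; the spearlet answers with InternalToolCreateResponse (internal-resp.fbs) carrying the assigned tool_id.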
+ +namespace spear.proto.tool; + +table InternalToolCreateParamSpec { + name: string (required); + description: string; + type: string (required); + required: bool; +} + +table InternalToolCreateRequest { + name: string (required); + description: string; + params: [InternalToolCreateParamSpec]; +} + +root_type InternalToolCreateRequest; diff --git a/proto/tool/internal-resp.fbs b/proto/tool/internal-resp.fbs new file mode 100644 index 0000000..0cf3e0b --- /dev/null +++ b/proto/tool/internal-resp.fbs @@ -0,0 +1,9 @@ +include "common.fbs"; + +namespace spear.proto.tool; + +table InternalToolCreateResponse { + tool_id: int64; +} + +root_type InternalToolCreateResponse; diff --git a/proto/tool/invoke-req.fbs b/proto/tool/invoke-req.fbs new file mode 100644 index 0000000..a35e411 --- /dev/null +++ b/proto/tool/invoke-req.fbs @@ -0,0 +1,16 @@ +include "common.fbs"; +namespace spear.proto.tool; + + +table Param { + key: string; + value: string; +} + +table ToolInvocationRequest { + tool_info: ToolInfo; + tool_name: string; + params: [Param]; +} + +root_type ToolInvocationRequest; diff --git a/proto/tool/invoke-resp.fbs b/proto/tool/invoke-resp.fbs new file mode 100644 index 0000000..572e17b --- /dev/null +++ b/proto/tool/invoke-resp.fbs @@ -0,0 +1,7 @@ +namespace spear.proto.tool; + +table ToolInvocationResponse { + result: string (required); +} + +root_type ToolInvocationResponse; diff --git a/proto/transform/common.fbs b/proto/transform/common.fbs new file mode 100644 index 0000000..96f0f42 --- /dev/null +++ b/proto/transform/common.fbs @@ -0,0 +1,24 @@ +include "../chat/request.fbs"; +include "../chat/response.fbs"; + +namespace spear.proto.transform; + +enum TransformType : int { + Image, + Text, + Audio, + Video, + Tensor, + Vector, + Unknown +} + +enum TransformOperation : int { + LLM, + Tools, + Embeddings, + OCR, + TextToSpeech, + SpeechToText, + TextToImage +} diff --git a/proto/transform/request.fbs b/proto/transform/request.fbs new file mode 100644 index 0000000..f3d6d9c --- /dev/null +++ b/proto/transform/request.fbs @@ -0,0 +1,23 @@ +include "../chat/request.fbs"; +include "../chat/response.fbs"; +include "common.fbs"; + +namespace spear.proto.transform; + +table TranformRequest_ParamsRaw { + data: string (required); +} + +union TransformRequest_Params { + spear.proto.chat.ChatCompletionRequest, + TranformRequest_ParamsRaw +} + +table TransformRequest { + input_types: [TransformType] (required); + output_types: [TransformType] (required); + operations: [TransformOperation] (required); + params: TransformRequest_Params (required); +} + +root_type TransformRequest; diff --git a/proto/transform/response.fbs b/proto/transform/response.fbs new file mode 100644 index 0000000..412c8e6 --- /dev/null +++ b/proto/transform/response.fbs @@ -0,0 +1,20 @@ +include "../chat/request.fbs"; +include "../chat/response.fbs"; +include "common.fbs"; + +namespace spear.proto.transform; + +table TranformResponse_DataRaw { + data: string (required); +} + +union TransformResponse_Data { + spear.proto.chat.ChatCompletionResponse, + TranformResponse_DataRaw +} + +table TransformResponse { + data: TransformResponse_Data; +} + +root_type TransformResponse; diff --git a/proto/transport/request.fbs b/proto/transport/request.fbs new file mode 100644 index 0000000..3815d79 --- /dev/null +++ b/proto/transport/request.fbs @@ -0,0 +1,31 @@ +include "../transform/request.fbs"; + +namespace spear.proto.transport; + +enum Method : uint32 { + Unknown = 0, + Transform, + TransformConfig, + // io related + Input, + Speak, + 
Record, + // tool and toolset + ToolInvoke, + InternalToolCreate, + // vec store + VecStoreCreate, + VecStoreInsert, + VecStoreQuery, + VecStoreDelete, + // Custom + Custom, +} + +table TransportRequest { + id: int64 = -1; // negative id means no need to response + method: Method = Unknown; + request: [ubyte]; +} + +root_type TransportRequest; diff --git a/proto/transport/response.fbs b/proto/transport/response.fbs new file mode 100644 index 0000000..4c9ac31 --- /dev/null +++ b/proto/transport/response.fbs @@ -0,0 +1,12 @@ +include "../transform/response.fbs"; + +namespace spear.proto.transport; + +table TransportResponse { + id: int64 = -1; // negative id is not valid + code: int32 = 0; + message: string; + response: [ubyte]; +} + +root_type TransportResponse; \ No newline at end of file diff --git a/proto/transport/signal.fbs b/proto/transport/signal.fbs new file mode 100644 index 0000000..5914703 --- /dev/null +++ b/proto/transport/signal.fbs @@ -0,0 +1,15 @@ +include "../transform/request.fbs"; + +namespace spear.proto.transport; + +enum Signal : uint32 { + Unknown = 0, + Terminate, +} + +table TransportSignal { + method: Signal = Unknown; + payload: [ubyte]; +} + +root_type TransportSignal; diff --git a/proto/transport/transport.fbs b/proto/transport/transport.fbs new file mode 100644 index 0000000..8385278 --- /dev/null +++ b/proto/transport/transport.fbs @@ -0,0 +1,17 @@ +include "request.fbs"; +include "response.fbs"; +include "signal.fbs"; + +namespace spear.proto.transport; + +union TransportMessageRaw_Data { + TransportRequest, + TransportResponse, + TransportSignal, +} + +table TransportMessageRaw { + data: TransportMessageRaw_Data (required); +} + +root_type TransportMessageRaw; diff --git a/sdk/cpp/Makefile b/sdk/cpp/Makefile new file mode 100644 index 0000000..3703c75 --- /dev/null +++ b/sdk/cpp/Makefile @@ -0,0 +1,18 @@ +.PHONY: all build clean + +CURRENT_DIR := $(shell pwd) +REPO_ROOT := $(shell git rev-parse --show-toplevel) + +all: include/proto + +include/proto: + allfiles=`find ${REPO_ROOT}/proto -name "*.fbs"`; \ + flatc -o ${CURRENT_DIR}/include/proto -I ${REPO_ROOT}/proto --cpp $${allfiles} + +clean: + rm -rf ${CURRENT_DIR}/include/proto; + +test: + @echo "No test for cpp sdk" + +.PHONY: all build clean test diff --git a/sdk/python/Makefile b/sdk/python/Makefile index a5a4186..a800d01 100644 --- a/sdk/python/Makefile +++ b/sdk/python/Makefile @@ -1,11 +1,29 @@ .PHONY: all build clean CURRENT_DIR := $(shell pwd) +REPO_ROOT := $(shell git rev-parse --show-toplevel) all: clean build -build: +build: spear/proto + python3 -m pip install -r requirements.txt; \ python3 -m build +spear/proto: + allfiles=`find ${REPO_ROOT}/proto -name "*.fbs"`; \ + flatc -o ${CURRENT_DIR}/ -I ${REPO_ROOT}/proto --python --python-typing $${allfiles} + clean: - rm -rf $(CURRENT_DIR)/dist $(CURRENT_DIR)/spear.egg-info + rm -rf ${CURRENT_DIR}/spear/proto && \ + rm -rf $(CURRENT_DIR)/dist $(CURRENT_DIR)/spear.egg-info && \ + find $(CURRENT_DIR) | grep -E "(__pycache__|\.pyc$$)" | xargs rm -rf + +install: build + pip uninstall spear -y; \ + pip install $(CURRENT_DIR)/dist/spear-*.whl + +uninstall: + pip uninstall spear -y + +test: build + PYTHONPATH=$(CURRENT_DIR) pytest --log-cli-level=DEBUG -s tests/ diff --git a/sdk/python/project.toml b/sdk/python/project.toml index 07de284..b15ead1 100644 --- a/sdk/python/project.toml +++ b/sdk/python/project.toml @@ -1,3 +1,3 @@ [build-system] -requires = ["setuptools", "wheel"] +requires = ["setuptools", "wheel", "pytest"] build-backend = 
"setuptools.build_meta" \ No newline at end of file diff --git a/sdk/python/requirements.txt b/sdk/python/requirements.txt new file mode 100644 index 0000000..11e8785 --- /dev/null +++ b/sdk/python/requirements.txt @@ -0,0 +1,4 @@ +setuptools +wheel +pytest +pytest-print \ No newline at end of file diff --git a/sdk/python/setup.py b/sdk/python/setup.py index 7784269..c20a6b9 100644 --- a/sdk/python/setup.py +++ b/sdk/python/setup.py @@ -2,7 +2,7 @@ setup( name="spear", - version="0.1", + version="0.0.1", description="Spear Python SDK", author="Wilson Wang", author_email="wilson.wang@bytedance.com", @@ -13,5 +13,13 @@ #dependencies install_requires=[ "dataclasses-json", + "flatbuffers", + "numpy", + ], + # packages for building + setup_requires=[ + "setuptools", + "wheel", + "pytest", ], ) diff --git a/sdk/python/spear/client.py b/sdk/python/spear/client.py index 0274669..ad28d6c 100644 --- a/sdk/python/spear/client.py +++ b/sdk/python/spear/client.py @@ -1,4 +1,4 @@ -import json +#!/usr/bin/env python3 import logging import os import queue @@ -7,134 +7,21 @@ import struct import threading import time +import traceback +from typing import Callable -import spear.hostcalls as hc +import flatbuffers as fbs -RPC_TYPE_REQ = 0 -RPC_TYPE_RESP_OK = 1 -RPC_TYPE_RESP_ERR = 2 -RPC_TYPE_RESP_NULL_OK = 3 +from spear.proto.custom import CustomRequest +from spear.proto.transport import (Method, TransportMessageRaw, + TransportMessageRaw_Data, TransportRequest, + TransportResponse, TransportSignal) MAX_INFLIGHT_REQUESTS = 128 - +DEFAULT_MESSAGE_SIZE = 4096 logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -def rpc_type(obj): - """ - determine the type of the rpc object - """ - if "method" in obj: - return RPC_TYPE_REQ - elif "result" in obj: - return RPC_TYPE_RESP_OK - elif "error" in obj: - return RPC_TYPE_RESP_ERR - else: - # empty return - # example: {'jsonrpc': '2.0', 'id': 5} - return RPC_TYPE_RESP_NULL_OK - - -class JsonRpcRequest(object): - """ - JsonRpcRequest is the request object for the rpc call - """ - - def __init__(self, rid, method, params): - self._method = method - self._params = params - self._id = rid - - def to_dict(self): - obj = {} - obj["method"] = self._method - obj["params"] = self._params - obj["id"] = self._id - return obj - - def build_response(self, result): - return JsonRpcOkResp(self._id, result) - - def build_error(self, code, message, data=None): - return JsonRpcErrorResp(self._id, code, message, data) - - @property - def method(self): - return self._method - - @property - def params(self): - return self._params - - @property - def id(self): - return self._id - - -class JsonRpcOkResp(object): - """ - JsonRpcOkResp is the response object for the successful rpc call - """ - - def __init__(self, rid, result): - self._result = result - self._id = rid - - def to_dict(self): - obj = {} - obj["result"] = self._result - obj["id"] = self._id - return obj - - @property - def result(self): - return self._result - - @property - def id(self): - return self._id - - -class JsonRpcErrorResp(object): - """ - JsonRpcErrorResp is the response object for the failed rpc call - """ - - def __init__(self, rid, code, message, data=None): - self._code = code - self._message = message - self._data = data - self._id = rid - - def to_dict(self): - """ - convert the object to dictionary - """ - obj = {} - obj["code"] = self._code - obj["message"] = self._message - obj["data"] = self._data - obj["id"] = self._id - return obj - - @property - def code(self): - return self._code - - 
@property - def message(self): - return self._message - - @property - def data(self): - return self._data - - @property - def id(self): - return self._id +logger.setLevel(logging.DEBUG) class HostAgent(object): @@ -145,10 +32,9 @@ class HostAgent(object): _instance = None def __init__(self): - self._client = None self._send_queue = None self._recv_queue = None - self._global_id = 0 + self._global_id = 1 self._send_task = None self._send_task_pipe_r, self._send_task_pipe_w = os.pipe() self._recv_task = None @@ -161,8 +47,9 @@ def __init__(self): self._inflight_requests_count = 0 self._pending_requests = {} self._pending_requests_lock = threading.Lock() + self._client = None - def __new__(cls): + def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super(HostAgent, cls).__new__(cls) return cls._instance @@ -172,13 +59,13 @@ def connect_host(self, host_addr: str, host_secret: int) -> socket: create a tcp connection to the server """ client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._client = client # convert the address to tuple host_addr = host_addr.split(":") host_addr = (host_addr[0], int(host_addr[1])) client.connect(host_addr) # send little endian secret 64-bit integer client.send(struct.pack(" MAX_INFLIGHT_REQUESTS: self._put_rpc_error( - rpc_data.id, + req.Id(), -32000, "Too many requests", "Too many requests", @@ -250,20 +158,41 @@ def handle_worker(handler, rpc_data): else: # create a thread to handle the request threading.Thread( - target=handle_worker, args=(handler, rpc_data) + target=handle_worker, + args=( + handler, + req.Id(), + custom_req.ParamsStr().decode("utf-8"), + ), ).start() - elif isinstance(rpc_data, JsonRpcOkResp) or isinstance( - rpc_data, JsonRpcErrorResp + elif ( + rpc_data.DataType() + == TransportMessageRaw_Data.TransportMessageRaw_Data.TransportResponse ): + # handle the response + # convert from TransportMessageRaw to TransportResponse + resp = TransportResponse.TransportResponse() + resp.Init(rpc_data.Data().Bytes, rpc_data.Data().Pos) with self._pending_requests_lock: - req = self._pending_requests.get(rpc_data.id) - if req is None: - logger.error("Invalid response id: %d", rpc_data.id) + if resp.Id() not in self._pending_requests: + logger.error("Invalid response id: %d", resp.Id()) else: - req["cb"](rpc_data) - del self._pending_requests[rpc_data.id] + req = self._pending_requests[resp.Id()] + req["cb"](resp) + del self._pending_requests[resp.Id()] + elif ( + rpc_data.DataType() + == TransportMessageRaw_Data.TransportMessageRaw_Data.TransportSignal + ): + sig = TransportSignal.TransportSignal() + sig.Init(rpc_data.Data().Bytes, rpc_data.Data().Pos) + if sig.Method() == Method.Method.Terminate: + logger.info("Terminating the agent") + self.stop() + return else: logger.error("Invalid rpc data") + raise ValueError("Invalid rpc data") def register_handler(self, method, handler): """ @@ -277,50 +206,32 @@ def unregister_handler(self, method): """ del self._handlers[method] - def _put_raw_object(self, obj): + def _put_raw_object(self, data: bytes): """ finalize the data and add it to the outgoing queue """ - logger.debug("Putting raw data to queue: %s", str(obj)) - json_data = json.dumps(obj, ensure_ascii=False, cls=hc.EnhancedJSONEncoder) - self._send_queue.put(json_data) + self._send_queue.put(data) os.write(self._send_task_pipe_w, b"\x01") - def _get_raw_object(self): + def _get_raw_data(self): """ - get the object from the incoming queue + get the data from the incoming queue """ - obj = self._recv_queue.get() - 
return obj - - def _get_rpc_data(self): - obj = self._get_raw_object() - if "jsonrpc" not in obj: - raise TypeError("Invalid jsonrpc version") - if obj["jsonrpc"] != "2.0": - raise TypeError("Invalid jsonrpc version") - if "id" not in obj: - raise TypeError("Invalid jsonrpc id") - rtype = rpc_type(obj) - if rtype == RPC_TYPE_REQ: - return JsonRpcRequest(obj["id"], obj["method"], obj["params"]) - elif rtype == RPC_TYPE_RESP_OK: - return JsonRpcOkResp(obj["id"], obj["result"]) - elif rtype == RPC_TYPE_RESP_ERR: - return JsonRpcErrorResp( - obj["id"], - obj["error"]["code"], - obj["error"]["message"], - obj["error"]["data"], + return self._recv_queue.get() + + def _get_rpc_data(self) -> TransportMessageRaw.TransportMessageRaw: + trans_resp = ( + TransportMessageRaw.TransportMessageRaw.GetRootAsTransportMessageRaw( + self._get_raw_data() ) - elif rtype == RPC_TYPE_RESP_NULL_OK: - return JsonRpcOkResp(obj["id"], None) - else: - raise TypeError("Invalid rpc object") + ) + if not isinstance(trans_resp, TransportMessageRaw.TransportMessageRaw): + raise ValueError("Invalid rpc data") + return trans_resp - def exec_request(self, method, param): + def exec_request(self, method: int, req_buf: bytes): """ - send the rpc request and return the response + send the rpc request and return the response as numpy array """ # create mutex mutex = threading.Lock() @@ -329,73 +240,132 @@ def exec_request(self, method, param): # create a list to store the response response = [] - def cb(rpc_data): + def cb(rpc_data: TransportResponse.TransportResponse): with mutex: response.append(rpc_data) cond.notify() - self._put_rpc_request(method, param, cb) + self._put_rpc_request(method, req_buf, cb) with mutex: cond.wait() - return response[0] - - def _put_rpc_request(self, method, param, cb): - obj = {} - obj["id"] = self._global_id + resp = response[0] + if resp.Code() != 0: + raise Exception(resp.Message()) + return resp.ResponseAsNumpy() + + def _put_rpc_request( + self, + method: int, + req_buf: bytes, + cb: Callable[[TransportResponse.TransportResponse], None], + ): + new_id = self._global_id self._global_id += 1 - obj["jsonrpc"] = "2.0" - obj["method"] = method - obj["params"] = param + builder = fbs.Builder(len(req_buf) + 1024) + req_buf_off = builder.CreateByteVector(req_buf) + + TransportRequest.Start(builder) + TransportRequest.AddId(builder, new_id) + TransportRequest.AddMethod(builder, method) + TransportRequest.AddRequest(builder, req_buf_off) + req_off = TransportRequest.End(builder) + + TransportMessageRaw.TransportMessageRawStart(builder) + TransportMessageRaw.AddDataType( + builder, TransportMessageRaw_Data.TransportMessageRaw_Data.TransportRequest + ) + TransportMessageRaw.AddData(builder, req_off) + msg_off = TransportMessageRaw.End(builder) + builder.Finish(msg_off) + + data = builder.Output() with self._pending_requests_lock: - self._pending_requests[obj["id"]] = { + self._pending_requests[new_id] = { "time": time.time(), - "obj": obj, + "obj": data, "cb": cb, } - self._put_raw_object(obj) - - def _put_rpc_response(self, req_id, result): - obj = {} - obj["id"] = req_id - obj["jsonrpc"] = "2.0" - obj["result"] = result - self._put_raw_object(obj) - - def _put_rpc_error(self, req_id, code, message, data=None): - obj = {} - obj["id"] = req_id - obj["jsonrpc"] = "2.0" - obj["error"] = {} - obj["error"]["code"] = code - obj["error"]["message"] = message - obj["error"]["data"] = data - self._put_raw_object(obj) + self._put_raw_object(data) + + def _put_rpc_response(self, req_id: int, result: bytes): + if 
result is None: + sz = 0 + else: + sz = len(result) + builder = fbs.Builder(sz + 512) + if result is not None: + result_off = builder.CreateByteVector(result) + + if req_id < 0: + raise ValueError("Invalid request id") + + TransportResponse.TransportResponseStart(builder) + TransportResponse.AddId(builder, req_id) + if result is not None: + TransportResponse.AddResponse(builder, result_off) + end = TransportResponse.End(builder) + + TransportMessageRaw.TransportMessageRawStart(builder) + TransportMessageRaw.AddDataType( + builder, TransportMessageRaw_Data.TransportMessageRaw_Data.TransportResponse + ) + TransportMessageRaw.AddData(builder, end) + end2 = TransportMessageRaw.End(builder) + builder.Finish(end2) + self._put_raw_object(builder.Output()) + + def _put_rpc_error(self, req_id: int, code: int, message, data=None): + builder = fbs.Builder(512 + len(message) + len(data)) + message_off = builder.CreateString(message) + if data is not None: + data_off = builder.CreateString(data) + else: + data_off = 0 + + if req_id < 0: + raise ValueError("Invalid request id") + + TransportResponse.TransportResponseStart(builder) + TransportResponse.AddId(builder, req_id) + TransportResponse.AddCode(builder, code) + TransportResponse.AddMessage(builder, message_off) + if data_off != 0: + TransportResponse.AddResponse(builder, data_off) + end = TransportResponse.End(builder) + + TransportMessageRaw.TransportMessageRawStart(builder) + TransportMessageRaw.AddDataType( + builder, TransportMessageRaw_Data.TransportMessageRaw_Data.TransportResponse + ) + TransportMessageRaw.AddData(builder, end) + end2 = TransportMessageRaw.End(builder) + builder.Finish(end2) + self._put_raw_object(builder.Output()) def _send_thread(self): """ send the data to the socket """ + def send_remaining_data(): while not self._send_queue.empty(): - strdata = self._send_queue.get() - data = strdata.encode("utf-8") + data = self._send_queue.get() + # data = strdata.encode("utf-8") length = len(data) lendata = length.to_bytes(8, byteorder="little") self._client.sendall(lendata) - logger.debug("Sending Data: %s", data) self._client.sendall(data) def send_data(): # clear the pipe os.read(self._send_task_pipe_r, 1) - strdata = self._send_queue.get() - data = strdata.encode("utf-8") + data = self._send_queue.get() + # data = strdata.encode("utf-8") # get the length of utf8 string length = len(data) lendata = length.to_bytes(8, byteorder="little") # send the length of the data self._client.sendall(lendata) - logger.debug("Sending Data: %s", data) self._client.sendall(data) sel = selectors.DefaultSelector() @@ -416,22 +386,25 @@ def _recv_thread(self): get the data from socket and parse it """ - def recv_data(): + def recv_data() -> bool: # read int64 from the socket and convert to integer data = self._client.recv(8) + if len(data) == 0: + return False length = int.from_bytes(data, byteorder="little") - logger.debug("Received Length: %d", length) - # read the json data + # read the data data = b"" while len(data) < length: try: tmp = self._client.recv(length - len(data)) + if len(tmp) == 0: + return False data += tmp except BlockingIOError as e: if e.errno == 11: continue - obj = json.loads(data.decode("utf-8")) - self._recv_queue.put(obj) + self._recv_queue.put(data) + return True sel = selectors.DefaultSelector() sel.register(self._client, selectors.EVENT_READ) @@ -442,7 +415,9 @@ def recv_data(): if key.fileobj == self._stop_event_r: return if key.fileobj == self._client: - recv_data() + if not recv_data(): + logger.info("Connection 
closed") + return def stop(self): """ diff --git a/sdk/python/spear/transform/__init__.py b/sdk/python/spear/transform/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sdk/python/spear/transform/chat.py b/sdk/python/spear/transform/chat.py new file mode 100644 index 0000000..7e8247c --- /dev/null +++ b/sdk/python/spear/transform/chat.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +import argparse +import base64 +import logging +import sys + +import flatbuffers as fbs +import spear.client as client +import spear.hostcalls.transform as tf + +from spear.proto.chat import (ChatCompletionRequest, ChatCompletionResponse, + ChatMessage, ChatMetadata, Role) +from spear.proto.chat import ToolInfo as ChatToolInfo +from spear.proto.tool import BuiltinToolInfo, ToolInfo +from spear.proto.transform import (TransformOperation, TransformRequest, + TransformRequest_Params, TransformResponse, + TransformResponse_Data, TransformType) +from spear.proto.transport import Method + +logging.basicConfig( + level=logging.DEBUG, # Set the desired logging level + # Customize the log format + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(stream=sys.stderr)], # Log to stderr +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +DEFAULT_LLM_MODEL = "llama" # "gpt-4o" + + +def chat(agent: client.HostAgent, message: str, + model: str = DEFAULT_LLM_MODEL, + builtin_tools: list[int] = []): + """ + handle the llm request + """ + builder = fbs.Builder(len(message) + 2048) + content_off = builder.CreateString(message) + model_off = builder.CreateString(model) + + tools_off = -1 + if len(builtin_tools) > 0: + builtin_tool_offs = [] + for tool in builtin_tools: + assert isinstance(tool, int) + BuiltinToolInfo.BuiltinToolInfoStart(builder) + BuiltinToolInfo.AddToolId(builder, tool) + tmp = BuiltinToolInfo.End(builder) + ChatToolInfo.ToolInfoStart(builder) + ChatToolInfo.ToolInfoAddData(builder, tmp) + ChatToolInfo.AddDataType( + builder, + ToolInfo.ToolInfo.BuiltinToolInfo, + ) + builtin_tool_offs.append(ChatToolInfo.End(builder)) + # create vector + ChatCompletionRequest.StartToolsVector(builder, len(builtin_tool_offs)) + for off in builtin_tool_offs: + builder.PrependUOffsetTRelative(off) + tools_off = builder.EndVector() + + ChatMetadata.ChatMetadataStart(builder) + ChatMetadata.AddRole(builder, Role.Role.User) + metadata_off = ChatMetadata.End(builder) + + ChatMessage.ChatMessageStart(builder) + ChatMessage.AddContent(builder, content_off) + ChatMessage.AddMetadata(builder, metadata_off) + msg_off = ChatMessage.End(builder) + + ChatCompletionRequest.StartMessagesVector(builder, 1) + builder.PrependUOffsetTRelative(msg_off) + msglist_off = builder.EndVector() + + ChatCompletionRequest.ChatCompletionRequestStart(builder) + ChatCompletionRequest.AddMessages(builder, msglist_off) + ChatCompletionRequest.AddModel(builder, model_off) + if tools_off != -1: + ChatCompletionRequest.AddTools(builder, tools_off) + chatcomp_off = ChatCompletionRequest.End(builder) + + TransformRequest.StartInputTypesVector(builder, 1) + builder.PrependInt32(TransformType.TransformType.Text) + input_types_off = builder.EndVector() + + TransformRequest.StartOutputTypesVector(builder, 1) + builder.PrependInt32(TransformType.TransformType.Text) + output_types_off = builder.EndVector() + + TransformRequest.StartOperationsVector(builder, 1) + builder.PrependInt32(TransformOperation.TransformOperation.LLM) + if len(builtin_tools) > 0: + 
builder.PrependInt32(TransformOperation.TransformOperation.Tools) + operations_off = builder.EndVector() + + TransformRequest.TransformRequestStart(builder) + TransformRequest.AddInputTypes(builder, input_types_off) + TransformRequest.AddOutputTypes(builder, output_types_off) + TransformRequest.AddOperations(builder, operations_off) + TransformRequest.AddParams(builder, chatcomp_off) + TransformRequest.AddParamsType( + builder, + TransformRequest_Params.TransformRequest_Params.spear_proto_chat_ChatCompletionRequest, + ) + builder.Finish(TransformRequest.End(builder)) + + data = agent.exec_request(Method.Method.Transform, builder.Output()) + + resp = TransformResponse.TransformResponse.GetRootAsTransformResponse( + data, 0) + if ( + resp.DataType() + != TransformResponse_Data.TransformResponse_Data.spear_proto_chat_ChatCompletionResponse + ): + raise ValueError("Unexpected response data type") + + chat_resp = ChatCompletionResponse.ChatCompletionResponse() + chat_resp.Init(resp.Data().Bytes, resp.Data().Pos) + + if chat_resp.Code() != 0: + raise ValueError(chat_resp.Error()) + + msg_len = chat_resp.MessagesLength() + res = [] + for i in range(msg_len): + res.append(chat_resp.Messages(i).Content().decode("utf-8")) + + return res diff --git a/sdk/python/spear/utils/io.py b/sdk/python/spear/utils/io.py index 344835b..03e21d7 100644 --- a/sdk/python/spear/utils/io.py +++ b/sdk/python/spear/utils/io.py @@ -1,48 +1,94 @@ #!/usr/bin/env python3 import logging +import flatbuffers as fbs import spear.client as client +from spear.proto.io import (InputRequest, InputResponse, RecordRequest, + RecordResponse, SpeakRequest, SpeakResponse) +from spear.proto.transport import Method + logger = logging.getLogger(__name__) -def input(agent: client.HostAgent, prompt: str) -> str: + +def input(agent: client.HostAgent, prompt: str, dryrun: bool = False) -> str: """ get user input """ - user_input = agent.exec_request( - "input", - prompt, + builder = fbs.Builder(len(prompt) + 32) + prompt_off = builder.CreateString(prompt) + InputRequest.InputRequestStart(builder) + InputRequest.InputRequestAddPrompt(builder, prompt_off) + InputRequest.AddDryrun(builder, dryrun) + data_off = InputRequest.InputRequestEnd(builder) + builder.Finish(data_off) + + data = agent.exec_request( + Method.Method.Input, + builder.Output(), ) - if isinstance(user_input, client.JsonRpcOkResp): - user_input = user_input.result - else: - raise ValueError("Error getting user input") - return user_input + + resp = InputResponse.InputResponse.GetRootAsInputResponse(data, 0) + return resp.Text() -def speak(agent: client.HostAgent, data) -> str: +def speak( + agent: client.HostAgent, + data: str, + model: str = None, + voice: str = None, + fmt: str = None, +) -> bytes: """ get user input """ + builder = fbs.Builder(len(data) + 32) + data_off = builder.CreateString(data) + if model: + model_off = builder.CreateString(model) + if voice: + voice_off = builder.CreateString(voice) + if fmt: + fmt_off = builder.CreateString(format) + SpeakRequest.SpeakRequestStart(builder) + SpeakRequest.SpeakRequestAddText(builder, data_off) + if model: + SpeakRequest.SpeakRequestAddModel(builder, model_off) + if voice: + SpeakRequest.SpeakRequestAddVoice(builder, voice_off) + if fmt: + SpeakRequest.SpeakRequestAddFormat(builder, fmt_off) + + data_off = SpeakRequest.SpeakRequestEnd(builder) + builder.Finish(data_off) res = agent.exec_request( - "speak", - data, + Method.Method.Speak, + builder.Output(), ) - if isinstance(res, client.JsonRpcOkResp): - return - else: - 
raise ValueError("Error speaking") + resp = SpeakResponse.SpeakResponse.GetRootAsSpeakResponse(res, 0) + return resp.Data() + -def record(agent: client.HostAgent, prompt: str) -> str: +def record(agent: client.HostAgent, prompt: str, + model: str = "whisper-1") -> str: """ get user input """ + builder = fbs.Builder(len(prompt) + 32) + prompt_off = builder.CreateString(prompt) + if model: + model_off = builder.CreateString(model) + + RecordRequest.RecordRequestStart(builder) + RecordRequest.RecordRequestAddPrompt(builder, prompt_off) + if model: + RecordRequest.RecordRequestAddModel(builder, model_off) + data_off = RecordRequest.RecordRequestEnd(builder) + builder.Finish(data_off) res = agent.exec_request( - "record", - prompt, + Method.Method.Record, + builder.Output(), ) - if isinstance(res, client.JsonRpcOkResp): - res = res.result - else: - raise ValueError("Error recording") - return res + + resp = RecordResponse.RecordResponse.GetRootAsRecordResponse(res, 0) + return resp.Text() diff --git a/sdk/python/spear/utils/tool.py b/sdk/python/spear/utils/tool.py new file mode 100644 index 0000000..a420797 --- /dev/null +++ b/sdk/python/spear/utils/tool.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +import logging +import inspect + +import flatbuffers as fbs +import spear.client as client + +from spear.proto.tool import ( + InternalToolCreateRequest, InternalToolCreateResponse, InternalToolCreateParamSpec) +from spear.proto.transport import Method + +logger = logging.getLogger(__name__) + + +def register_internal_tool(agent: client.HostAgent, cb: callable, + name: str = None, desc: str = None) -> int: + """ + register internal tool + """ + builder = fbs.Builder(32) + + if name is None: + name = cb.__name__ + tool_name_off = builder.CreateString(name) + + if desc is None: + desc = inspect.getdoc(cb) + tool_desc_off = builder.CreateString(desc) + # parse all parameters from desc in the form of + # @param name: description + # new line is not supported yet + param_desc = {} + logger.info("desc: %s", desc) + for line in desc.split("\n"): + # trim leading and trailing spaces and tabs + line = line.strip() + if not line: + continue + if not line.startswith("@param"): + continue + # split by : + name, desc = line[7:].split(":", 1) + param_desc[name.strip()] = builder.CreateString(desc.strip()) + logger.info("param_desc: %s", param_desc) + + # create all parameters info + sig = inspect.signature(cb) + params = [] + names = {} + types = {} + for p in sig.parameters: + names[p] = builder.CreateString(p) + if sig.parameters[p].annotation is inspect.Parameter.empty: + types[p] = builder.CreateString("string") + else: + types[p] = builder.CreateString(sig.parameters[p].annotation) + for p in sig.parameters: + InternalToolCreateParamSpec.InternalToolCreateParamSpecStart(builder) + InternalToolCreateParamSpec.InternalToolCreateParamSpecAddName( + builder, names[p]) + InternalToolCreateParamSpec.InternalToolCreateParamSpecAddType( + builder, types[p]) + if p in param_desc: + InternalToolCreateParamSpec.InternalToolCreateParamSpecAddDescription( + builder, param_desc[p]) + InternalToolCreateParamSpec.InternalToolCreateParamSpecAddRequired( + builder, sig.parameters[p].default is inspect.Parameter.empty) + params.append( + InternalToolCreateParamSpec.InternalToolCreateParamSpecEnd(builder)) + + InternalToolCreateRequest.InternalToolCreateRequestStartParamsVector( + builder, len(params)) + for p in reversed(params): + builder.PrependUOffsetTRelative(p) + params_off = builder.EndVector() + + 
InternalToolCreateRequest.InternalToolCreateRequestStart(builder) + InternalToolCreateRequest.InternalToolCreateRequestAddName( + builder, tool_name_off) + InternalToolCreateRequest.InternalToolCreateRequestAddDescription( + builder, tool_desc_off) + InternalToolCreateRequest.InternalToolCreateRequestAddParams( + builder, params_off) + data_off = InternalToolCreateRequest.InternalToolCreateRequestEnd(builder) + builder.Finish(data_off) + + data = agent.exec_request( + Method.Method.InternalToolCreate, + builder.Output(), + ) + + resp = InternalToolCreateResponse.InternalToolCreateResponse.\ + GetRootAsInternalToolCreateResponse( + data, 0) + return resp.ToolId() diff --git a/sdk/python/tests/__init__.py b/sdk/python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sdk/python/tests/proto/__init__.py b/sdk/python/tests/proto/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sdk/python/tests/proto/server.py b/sdk/python/tests/proto/server.py new file mode 100644 index 0000000..7e95914 --- /dev/null +++ b/sdk/python/tests/proto/server.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +import logging +import socket +import struct +import threading + +import flatbuffers as fbs + +from spear.proto.chat import (ChatCompletionRequest, ChatCompletionResponse, + ChatMessage, ChatMetadata) +from spear.proto.io import InputRequest, InputResponse +from spear.proto.transform import (TransformRequest, TransformRequest_Params, + TransformResponse, TransformResponse_Data) +from spear.proto.transport import (Method, TransportMessageRaw, + TransportMessageRaw_Data, TransportRequest, + TransportResponse) + +MAX_INFLIGHT_REQUESTS = 128 +DEFAULT_MESSAGE_SIZE = 4096 +TEST_SERVER_DEFAULT_PORT = 12345 +TEST_SERVER_DEFAULT_SECRET = 12345 + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class TestAgentServer: + """ + A test tcp server for testing the agent + """ + + def __init__( + self, + port: int = TEST_SERVER_DEFAULT_PORT, + secret: int = TEST_SERVER_DEFAULT_SECRET, + ): + self._server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._server.bind(("localhost", port)) + self._server.listen(5) + self._client = None + self._secret = secret + + def run(self): + """ + run the server + """ + while True: + client, _ = self._server.accept() + # get the secret + data = client.recv(8) + secret = struct.unpack("= 0; i-- { + content := builder.CreateString(msgList[i].Content) + + chat.ChatMetadataStart(builder) + chat.ChatMetadataAddRole(builder, + roleStrToRole(msgList[i].Metadata["role"].(string))) + if msgList[i].Metadata["reason"] != nil { + chat.ChatMetadataAddReason(builder, + reasonStrToReason(msgList[i].Metadata["reason"].(string))) + } + metaOff := chat.ChatMetadataEnd(builder) + + chat.ChatMessageStart(builder) + chat.ChatMessageAddContent(builder, content) + chat.ChatMessageAddMetadata(builder, metaOff) + off := chat.ChatMessageEnd(builder) + + msgOff[i] = off + } + + chat.ChatCompletionResponseStartMessagesVector(builder, len(msgList)) + for i := len(msgList) - 1; i >= 0; i-- { + builder.PrependUOffsetT(msgOff[i]) + } + msgs := builder.EndVector(len(msgList)) + + chat.ChatCompletionResponseStart(builder) + chat.ChatCompletionResponseAddMessages(builder, msgs) + res := chat.ChatCompletionResponseEnd(builder) + + transform.TransformResponseStart(builder) + transform.TransformResponseAddData(builder, res) + transform.TransformResponseAddDataType(builder, + 
transform.TransformResponse_Dataspear_proto_chat_ChatCompletionResponse) + builder.Finish(transform.TransformResponseEnd(builder)) + + return builder.FinishedBytes() +} + +func ChatCompletionWithTools(inv *hcommon.InvocationInfo, + args *transform.TransformRequest) ([]byte, error) { + return chatCompletion(inv, args, true) +} + +func ChatCompletionNoTools(inv *hcommon.InvocationInfo, + args *transform.TransformRequest) ([]byte, error) { + return chatCompletion(inv, args, false) +} + +func chatCompletion(inv *hcommon.InvocationInfo, args *transform.TransformRequest, + hasTool bool) ([]byte, error) { + // verify the type of args is ChatCompletionRequest + chatReq := chat.ChatCompletionRequest{} + if err := helper.UnwrapTransformRequest(&chatReq, args); err != nil { + return nil, err + } + + if !hasTool && chatReq.ToolsLength() > 0 { + log.Infof("Tools are not supported in this function") + return nil, fmt.Errorf("tools are not supported in this function") + } + + log.Infof("Using model %s", chatReq.Model()) + + msgList, err := innerChatCompletion(inv, &chatReq, hasTool) + if err != nil { + return nil, fmt.Errorf("error calling innerChatCompletionNoTools: %v", err) + } + + buf := chatMessageToTransformBuffer(msgList) + return buf, nil +} + +func innerChatCompletion(inv *hcommon.InvocationInfo, chatReq *chat.ChatCompletionRequest, + hasTool bool) ([]ChatMessage, error) { + mem := NewChatCompletionMemory() + for idx := range chatReq.MessagesLength() { + msg := chat.ChatMessage{} + if !chatReq.Messages(&msg, idx) { + return nil, fmt.Errorf("error getting message") + } + + meta := chat.ChatMetadata{} + msg.Metadata(&meta) + tmp := ChatMessage{ + Metadata: map[string]interface{}{ + "role": roleToRoleStr(meta.Role()), + }, + Content: string(msg.Content()), + } + mem.AddMessage(tmp) + } + + var respData *hcopenai.OpenAIChatCompletionResponse + var err error + var count int + for count = 0; count < chatInnerLoopMaxCount; count++ { + // create a new chat request + openAiChatReq2 := hcopenai.OpenAIChatCompletionRequest{ + Model: string(chatReq.Model()), + Messages: []hcopenai.OpenAIChatMessage{}, + } + // build the messages + for _, msg := range mem.GetMessages() { + tmp := hcopenai.OpenAIChatMessage{ + Content: msg.Content, + } + if msg.Metadata["role"] != nil { + tmp.Role = msg.Metadata["role"].(string) + } + if msg.Metadata["tool_call_id"] != nil { + tmp.ToolCallId = msg.Metadata["tool_call_id"].(string) + } + if msg.Metadata["tool_calls"] != nil { + tmp.ToolCalls = msg.Metadata["tool_calls"].([]hcopenai.OpenAIChatToolCall) + } + openAiChatReq2.Messages = append(openAiChatReq2.Messages, tmp) + } + if hasTool { + // build the tools + if chatReq.ToolsLength() > 0 { + tmp := false + openAiChatReq2.ParallelToolCalls = &tmp + } + for idx := range chatReq.ToolsLength() { + toolInfo := chat.ToolInfo{} + if !chatReq.Tools(&toolInfo, idx) { + return nil, fmt.Errorf("error getting tool info") + } + tbl := flatbuffers.Table{} + if !toolInfo.Data(&tbl) { + return nil, fmt.Errorf("error getting params") + } + switch toolInfo.DataType() { + case tool.ToolInfoBuiltinToolInfo: + toolInfo := tool.BuiltinToolInfo{} + toolInfo.Init(tbl.Bytes, tbl.Pos) + tool, ok := hcommon.GetBuiltinTool(hcommon.BuiltinToolID(toolInfo.ToolId())) + if !ok { + return nil, fmt.Errorf("builtin tool not found") + } + requiredParams := make([]string, 0) + t := hcopenai.OpenAIChatToolFunction{ + Type: "function", + Func: hcopenai.OpenAIChatToolFunctionSub{ + Name: fmt.Sprintf("B-%d", tool.Id), + Description: tool.Description, + Parameters: 
hcopenai.OpenAIChatToolParameter{ + Type: "object", + AdditionalProperties: false, + Properties: make(map[string]hcopenai.OpenAIChatToolParameterProperty), + }, + }, + } + for k, v := range tool.Params { + t.Func.Parameters.Properties[k] = hcopenai.OpenAIChatToolParameterProperty{ + Type: v.Ptype, + Description: v.Description, + } + if v.Required { + requiredParams = append(requiredParams, k) + } + } + t.Func.Parameters.Required = requiredParams + openAiChatReq2.Tools = append(openAiChatReq2.Tools, t) + case tool.ToolInfoInternalToolInfo: + toolInfo := tool.InternalToolInfo{} + toolInfo.Init(tbl.Bytes, tbl.Pos) + // TODO: implement this + panic("not implemented") + case tool.ToolInfoNormalToolInfo: + toolInfo := tool.NormalToolInfo{} + toolInfo.Init(tbl.Bytes, tbl.Pos) + // TODO: implement this + panic("not implemented") + default: + return nil, fmt.Errorf("unexpected tool info data type") + } + } + } + + ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeChatWithTools, + openAiChatReq2.Model) + if len(ep) == 0 { + return nil, fmt.Errorf("no endpoint found") + } + respData, err = hcopenai.OpenAIChatCompletion(ep[0], &openAiChatReq2) + if err != nil { + return nil, fmt.Errorf("error calling OpenAIChatCompletion: %v", err) + } + + if len(respData.Choices) == 0 { + return nil, fmt.Errorf("no choices found") + } + if len(respData.Choices) > 1 { + return nil, fmt.Errorf("multiple choices found") + } + + choice := respData.Choices[0] + if choice.Index != 0 { + return nil, fmt.Errorf("index mismatch") + } + + if choice.Reason == "stop" || choice.Reason == "length" { + mem.AddMessage(ChatMessage{ + Metadata: map[string]interface{}{ + "role": choice.Message.Role, + "reason": choice.Reason, + }, + Content: choice.Message.Content, + }) + // we have done with the chat + if count == chatInnerLoopMaxCount { + count = 0 + } + break + } else if choice.Reason == "tool_calls" || len(choice.Message.ToolCalls) > 0 { + // do tool calls + if !hasTool { + log.Errorf("Unexpected tool calls") + return nil, fmt.Errorf("unexpected tool calls") + } + if choice.Message.Content != "" { + return nil, fmt.Errorf("unexpected content") + } + + mem.AddMessage(ChatMessage{ + Metadata: map[string]interface{}{ + "role": choice.Message.Role, + "tool_calls": choice.Message.ToolCalls, + }, + Content: choice.Message.Content, + }) + + toolCalls := choice.Message.ToolCalls + for _, toolCall := range toolCalls { + argsStr := toolCall.Function.Arguments + // use json to unmarshal the arguments to interface{} + var args interface{} = nil + if argsStr != "" { + err := json.Unmarshal([]byte(argsStr), &args) + if err != nil { + return nil, fmt.Errorf("error unmarshalling tool call arguments: %v", + err) + } + } + // check the tool type here + // the tool name should be in the format of "B-" when it is a built-in tool + // the tool name should be in the format of "I-" when it is an internal tool + // the tool name should be in the format of "N-" when it is a normal tool + toolName := toolCall.Function.Name + if len(toolName) < 3 || toolName[1] != '-' { + return nil, fmt.Errorf("invalid tool name") + } + toolType := toolName[:1] + // convert string to uint16 using strconv + toolId, err := strconv.ParseUint(toolName[2:], 10, 16) + if err != nil { + return nil, fmt.Errorf("error parsing tool id: %v", err) + } + switch toolType { + case "B": + // it is a built-in tool + toolReg, ok := hcommon.GetBuiltinTool(hcommon.BuiltinToolID(toolId)) + if !ok { + return nil, fmt.Errorf("builtin tool not found") + } + fn := toolReg.CbBuiltIn + if fn 
== nil { + return nil, fmt.Errorf("built-in tool not implemented") + } + res, err := fn(inv, args) + if err != nil { + return nil, fmt.Errorf("error calling built-in tool %s: %v", + toolReg.Name, err) + } + + tmp := fmt.Sprintf("%v", res) + if len(tmp) > 512 { + tmp = tmp[:509] + "..." + } + log.Infof("Builtin Tool call response: %v", tmp) + + mem.AddMessage(ChatMessage{ + Metadata: map[string]interface{}{ + "role": "tool", + "tool_call_id": toolCall.Id, + }, + Content: fmt.Sprintf("%v", res), + }) + case "I": + // it is an internal tool + case "N": + // it is a normal tool + default: + return nil, fmt.Errorf("invalid tool type") + } + } + } else { + return nil, fmt.Errorf("unexpected reason: %s", choice.Reason) + } + } + + if count == chatInnerLoopMaxCount { + return nil, fmt.Errorf("max count reached") + } + + return mem.GetMessages(), nil +} diff --git a/spearlet/hostcalls/common/common.go b/spearlet/hostcalls/common/common.go new file mode 100644 index 0000000..9a29a69 --- /dev/null +++ b/spearlet/hostcalls/common/common.go @@ -0,0 +1,321 @@ +package common + +import ( + "fmt" + "sync" + "time" + + flatbuffers "github.com/google/flatbuffers/go" + "github.com/lfedgeai/spear/pkg/spear/proto/transport" + "github.com/lfedgeai/spear/pkg/utils/protohelper" + "github.com/lfedgeai/spear/spearlet/task" + log "github.com/sirupsen/logrus" +) + +type HostCall struct { + NameID transport.Method + Handler HostCallHandler +} + +// invocation info +type InvocationInfo struct { + Task task.Task + CommMgr *CommunicationManager +} + +type RespChanData struct { + Resp *transport.TransportResponse + InvInfo *InvocationInfo +} + +type ReqChanData struct { + Req *transport.TransportRequest + InvInfo *InvocationInfo +} + +// communication manager for hostcalls and guest responses +type CommunicationManager struct { + respCh chan *RespChanData // incoming responses + reqCh chan *ReqChanData // incoming requests + outCh map[task.Task]chan task.Message + + pendingRequests map[int64]*requestCallback + pendingRequestsMu sync.RWMutex +} + +type HostCallHandler func(inv *InvocationInfo, args []byte) ([]byte, error) + +type HostCalls struct { + // map of hostcalls + HCMap map[transport.Method]HostCallHandler + CommMgr *CommunicationManager +} + +var ResponseTimeout = 5 * time.Minute + +func NewHostCalls(commMgr *CommunicationManager) *HostCalls { + return &HostCalls{ + HCMap: make(map[transport.Method]HostCallHandler), + CommMgr: commMgr, + } +} + +func (h *HostCalls) RegisterHostCall(hc *HostCall) error { + nameId := hc.NameID + handler := hc.Handler + log.Debugf("Registering hostcall: %v", nameId) + if _, ok := h.HCMap[nameId]; ok { + return fmt.Errorf("hostcall already registered: %v", nameId) + } + h.HCMap[nameId] = handler + return nil +} + +func (h *HostCalls) Run() { + for { + entry := h.CommMgr.GetIncomingRequest() + req := entry.Req + inv := entry.InvInfo + if handler, ok := h.HCMap[req.Method()]; ok { + result, err := handler(inv, req.RequestBytes()) + if err != nil { + log.Errorf("Error executing hostcall: %v", err) + if err := h.CommMgr.SendOutgoingRPCResponseError(inv.Task, req.Id(), -1, + err.Error()); err != nil { + log.Errorf("Error sending response: %v", err) + } + } else { + // send success response + log.Infof("Hostcall success: %v, ID %d", req.Method(), req.Id()) + if err := h.CommMgr.SendOutgoingRPCResponse(inv.Task, req.Id(), + result); err != nil { + log.Errorf("Error sending response: %v", err) + } + } + } else { + log.Errorf("Hostcall not found: %v", req.Method()) + if err :=
h.CommMgr.SendOutgoingRPCResponseError(inv.Task, req.Id(), 2, + "method not found"); err != nil { + log.Errorf("Error sending response: %v", err) + } + } + } +} + +func NewCommunicationManager() *CommunicationManager { + return &CommunicationManager{ + respCh: make(chan *RespChanData, 1024), + reqCh: make(chan *ReqChanData, 1024), + outCh: make(map[task.Task]chan task.Message), + + pendingRequests: make(map[int64]*requestCallback), + pendingRequestsMu: sync.RWMutex{}, + } +} + +func (c *CommunicationManager) InstallToTask(t task.Task) error { + if t == nil { + log.Errorf("task is nil") + return fmt.Errorf("task is nil") + } + + // check in and out channel + in, out, err := t.CommChannels() + if err != nil { + log.Errorf("Error getting communication channels: %v", err) + return err + } + + c.outCh[t] = in + + go func() { + inv := InvocationInfo{ + Task: t, + CommMgr: c, + } + + for msg := range out { + // process message + transRaw := transport.GetRootAsTransportMessageRaw(msg, 0) + if transRaw == nil { + log.Errorf("Error getting transport message raw") + continue + } + if transRaw.DataType() == transport.TransportMessageRaw_DataTransportRequest { + // request + req := transport.TransportRequest{} + // convert to transport request + reqTbl := &flatbuffers.Table{} + if !transRaw.Data(reqTbl) { + log.Errorf("Error getting transport request table") + continue + } + req.Init(reqTbl.Bytes, reqTbl.Pos) + log.Debugf("Hostcall received request: %d", req.Method()) + c.reqCh <- &ReqChanData{ + Req: &req, + InvInfo: &inv, + } + } else if transRaw.DataType() == transport.TransportMessageRaw_DataTransportResponse { + // response + resp := transport.TransportResponse{} + // convert to transport response + respTbl := &flatbuffers.Table{} + if !transRaw.Data(respTbl) { + log.Errorf("Error getting transport response table") + continue + } + resp.Init(respTbl.Bytes, respTbl.Pos) + log.Debugf("Hostcall received response: %d", resp.Id()) + go func() { + // check if it is response to a pending request + c.pendingRequestsMu.RLock() + entry, ok := c.pendingRequests[resp.Id()] + c.pendingRequestsMu.RUnlock() + if ok { + cb := entry.cb + if err := cb(&resp); err != nil { + log.Errorf("Error handling response: %v", err) + } + if entry.autoClear { + c.pendingRequestsMu.Lock() + delete(c.pendingRequests, resp.Id()) + c.pendingRequestsMu.Unlock() + } + return + } + + // this is when we receive a response that is not a pending request + c.respCh <- &RespChanData{ + Resp: &resp, + InvInfo: &inv, + } + }() + + } else if transRaw.DataType() == transport.TransportMessageRaw_DataTransportSignal { + sig := transport.TransportSignal{} + sigTbl := &flatbuffers.Table{} + if !transRaw.Data(sigTbl) { + log.Errorf("Error getting transport signal table") + continue + } + sig.Init(sigTbl.Bytes, sigTbl.Pos) + log.Debugf("Platform received signal: %s", sig.Method().String()) + } else { + log.Errorf("Invalid transport message type: %d", transRaw.DataType()) + } + } + }() + + return nil +} + +func (c *CommunicationManager) GetIncomingRequest() *ReqChanData { + return <-c.reqCh +} + +func (c *CommunicationManager) GetIncomingResponse() *RespChanData { + return <-c.respCh +} + +func (c *CommunicationManager) SendOutgoingRPCResponseError(t task.Task, id int64, code int, + msg string) error { + resp := protohelper.CreateErrorTransportResponse(id, code, msg) + if resp == nil { + return fmt.Errorf("error creating response") + } + data, err := protohelper.TransportResponseToRaw(resp) + if err != nil { + return err + } + c.outCh[t] <- data + 
return nil +} + +func (c *CommunicationManager) SendOutgoingRPCResponse(t task.Task, id int64, + result []byte) error { + raw, err := protohelper.RPCBufferResponseToRaw(id, result) + if err != nil { + return err + } + + c.outCh[t] <- raw + return nil +} + +type ResquestCallback func(resp *transport.TransportResponse) error + +type requestCallback struct { + cb ResquestCallback + autoClear bool + ts time.Time +} + +func (c *CommunicationManager) SendOutgoingRPCSignal(t task.Task, signal transport.Signal, + data []byte) error { + data, err := protohelper.RPCSignalToRaw(signal, data) + if err != nil { + return err + } + + c.outCh[t] <- data + return nil +} + +// req_buffer is +func (c *CommunicationManager) SendOutgoingRPCRequestCallback(t task.Task, id int64, + method transport.Method, + req_buffer []byte, cb func(*transport.TransportResponse) error) error { + if len(req_buffer) == 0 { + return fmt.Errorf("request is nil") + } + + data, err := protohelper.RPCBufferResquestToRaw(id, method, req_buffer) + if err != nil { + return err + } + + c.outCh[t] <- data + c.pendingRequestsMu.Lock() + c.pendingRequests[id] = &requestCallback{ + cb: cb, + autoClear: true, + ts: time.Now(), + } + c.pendingRequestsMu.Unlock() + return nil +} + +// users need to specify the id in the request +// req_buffer is the serialized transport.TransportRequest +func (c *CommunicationManager) SendOutgoingRPCRequest(t task.Task, method transport.Method, + req_buffer []byte) (*transport.TransportResponse, error) { + ch := make(chan *transport.TransportResponse, 1) + errCh := make(chan error, 1) + + req := transport.GetRootAsTransportRequest(req_buffer, 0) + if req == nil { + return nil, fmt.Errorf("error getting transport request") + } + + if err := c.SendOutgoingRPCRequestCallback(t, int64(t.NextRequestID()), method, req_buffer, + func(resp *transport.TransportResponse) error { + if resp.Code() != 0 { + errCh <- fmt.Errorf("error response: %v", resp.Code()) + } else { + ch <- resp + } + return nil + }); err != nil { + return nil, err + } + + select { + case resp := <-ch: + return resp, nil + case err := <-errCh: + return nil, err + case <-time.After(ResponseTimeout): + return nil, fmt.Errorf("timeout") + } +} diff --git a/worker/hostcalls/common/models.go b/spearlet/hostcalls/common/models.go similarity index 87% rename from worker/hostcalls/common/models.go rename to spearlet/hostcalls/common/models.go index 0ee588e..47a1cb0 100644 --- a/worker/hostcalls/common/models.go +++ b/spearlet/hostcalls/common/models.go @@ -12,7 +12,7 @@ type OpenAIFunctionType int type APIEndpointInfo struct { Name string Model string - Base string + Base *string APIKey string Url string } @@ -26,7 +26,7 @@ const ( OpenAIFunctionTypeImageGeneration ) -const ( +var ( OpenAIBase = "https://api.chatanywhere.tech/v1" GaiaToolLlamaGroqBase = "https://llamatool.us.gaianet.network/v1" GaiaToolLlama70BBase = "https://llama70b.gaia.domains/v1" @@ -43,42 +43,42 @@ var ( { Name: "openai-toolchat", Model: "gpt-4o", - Base: OpenAIBase, + Base: &OpenAIBase, APIKey: os.Getenv("OPENAI_API_KEY"), Url: "/chat/completions", }, { Name: "qwen-toolchat-72b", Model: "qwen", - Base: GaiaToolQWen72BBase, + Base: &GaiaToolQWen72BBase, APIKey: "gaia", Url: "/chat/completions", }, { Name: "llama-toolchat-70b", Model: "llama", - Base: GaiaToolLlama70BBase, + Base: &GaiaToolLlama70BBase, APIKey: "gaia", Url: "/chat/completions", }, { Name: "llama-toolchat", Model: "llama", - Base: GaiaToolLlamaGroqBase, + Base: &GaiaToolLlamaGroqBase, APIKey: "gaia", Url: 
"/chat/completions", }, { Name: "llama-toolchat-8b", Model: "llama", - Base: GaiaToolLlama8BBase, + Base: &GaiaToolLlama8BBase, APIKey: "gaia", Url: "/chat/completions", }, { Name: "llama-toolchat-3b", Model: "llama", - Base: GaiaToolLlama3BBase, + Base: &GaiaToolLlama3BBase, APIKey: "gaia", Url: "/chat/completions", }, @@ -87,13 +87,13 @@ var ( { Name: "openai-chat", Model: "gpt-4o", - Base: OpenAIBase, + Base: &OpenAIBase, APIKey: os.Getenv("OPENAI_API_KEY"), Url: "/chat/completions"}, { Name: "llama-chat", Model: "llama", - Base: GaiaToolLlama8BBase, + Base: &GaiaToolLlama8BBase, APIKey: "gaia", Url: "/chat/completions", }, @@ -102,14 +102,14 @@ var ( { Name: "openai-embed", Model: "text-embedding-ada-002", - Base: OpenAIBase, + Base: &OpenAIBase, APIKey: os.Getenv("OPENAI_API_KEY"), Url: "/embeddings", }, { Name: "nomic-embed", Model: "nomic-embed", - Base: GaiaToolLlama8BBase, + Base: &GaiaToolLlama8BBase, APIKey: "gaia", Url: "/embeddings", }, @@ -118,7 +118,7 @@ var ( { Name: "openai-tts", Model: "tts-1", - Base: OpenAIBase, + Base: &OpenAIBase, APIKey: os.Getenv("OPENAI_API_KEY"), Url: "/audio/speech", }, @@ -127,7 +127,7 @@ var ( { Name: "openai-genimage", Model: "dall-e-3", - Base: OpenAIBase, + Base: &OpenAIBase, APIKey: os.Getenv("OPENAI_API_KEY"), Url: "/images/generations", }, @@ -136,14 +136,14 @@ var ( { Name: "gaia-whisper", Model: "whisper", - Base: GaiaWhisperBase, + Base: &GaiaWhisperBase, APIKey: "gaia", Url: "/audio/transcriptions", }, { Name: "openai-whisper", Model: "whisper-1", - Base: OpenAIBase, + Base: &OpenAIBase, APIKey: os.Getenv("OPENAI_API_KEY"), Url: "/audio/transcriptions", }, @@ -172,7 +172,7 @@ func GetAPIEndpointInfo(ft OpenAIFunctionType, modelOrName string) []APIEndpoint } tmpList = append(tmpList, *tmp) } - log.Infof("Found %d endpoints for %s: %v", len(tmpList), modelOrName, tmpList) + log.Infof("Found %d endpoint(s) for %s: %v", len(tmpList), modelOrName, tmpList) return res } @@ -191,5 +191,7 @@ func init() { newAPIEndpointMap[ft] = newInfoList } APIEndpointMap = newAPIEndpointMap + } else { + OpenAIBase = "https://api.openai.com/v1" } } diff --git a/spearlet/hostcalls/common/tools.go b/spearlet/hostcalls/common/tools.go new file mode 100644 index 0000000..15e3bc2 --- /dev/null +++ b/spearlet/hostcalls/common/tools.go @@ -0,0 +1,133 @@ +package common + +import ( + "fmt" + + "github.com/lfedgeai/spear/pkg/spear/proto/tool" + "github.com/lfedgeai/spear/spearlet/task" +) + +type BuiltInToolCbFunc func(inv *InvocationInfo, args interface{}) (interface{}, error) + +type ToolParam struct { + Ptype string + Description string + Required bool +} + +type ToolId uint16 +type BuiltinToolID ToolId +type InternalToolID ToolId + +type ToolRegistry struct { + ToolType ToolType + Name string + Id BuiltinToolID + Description string + Params map[string]ToolParam + CbIdInternal InternalToolID + CbBuiltIn BuiltInToolCbFunc +} + +type ToolType int + +const ( + ToolType_Invalid ToolType = iota + ToolType_Internal + ToolType_Builtin + ToolType_Normal +) + +const ( + BuiltinToolID_Invalid = BuiltinToolID(tool.BuiltinToolIDInvalid) + BuiltinToolID_Datetime = BuiltinToolID(tool.BuiltinToolIDDatetime) + BuiltinToolID_Sleep = BuiltinToolID(tool.BuiltinToolIDSleep) + BuiltinToolID_SearchContactEmail = BuiltinToolID(tool.BuiltinToolIDSearchContactEmail) + // email tools + BuiltinToolID_ListOpenEmails = BuiltinToolID(tool.BuiltinToolIDListOpenEmails) + BuiltinToolID_ComposeEmail = BuiltinToolID(tool.BuiltinToolIDComposeEmail) + BuiltinToolID_SendEmailDraftWindow = 
BuiltinToolID(tool.BuiltinToolIDSendEmailDraftWindow) + // mouse tools + BuiltinToolID_MouseRightClick = BuiltinToolID(tool.BuiltinToolIDMouseRightClick) + BuiltinToolID_MouseLeftClick = BuiltinToolID(tool.BuiltinToolIDMouseLeftClick) + // phone tools + BuiltinToolID_PhoneCall = BuiltinToolID(tool.BuiltinToolIDPhoneCall) + // screen tools + BuiltinToolID_FullScreenshot = BuiltinToolID(tool.BuiltinToolIDFullScreenshot) + // web tools + BuiltinToolID_OpenURL = BuiltinToolID(tool.BuiltinToolIDOpenURL) + BuiltinToolID_ScrollDown = BuiltinToolID(tool.BuiltinToolIDScrollDown) + BuiltinToolID_ScrollUp = BuiltinToolID(tool.BuiltinToolIDScrollUp) + BuiltinToolID_PageDown = BuiltinToolID(tool.BuiltinToolIDPageDown) + BuiltinToolID_PageUp = BuiltinToolID(tool.BuiltinToolIDPageUp) + BuiltinToolID_WebScreenshot = BuiltinToolID(tool.BuiltinToolIDWebScreenshot) + + BuiltinToolID_Max = BuiltinToolID(tool.BuiltinToolIDMax) +) + +var ( + taskInternalTools = map[task.TaskID][]ToolRegistry{} + builtinTools = map[BuiltinToolID]ToolRegistry{} +) + +// builtin tools + +func RegisterBuiltinTool(tool ToolRegistry) error { + if _, ok := builtinTools[tool.Id]; ok { + return fmt.Errorf("duplicate tool registration") + } + builtinTools[tool.Id] = tool + return nil +} + +func UnregisterBuiltinTool(id BuiltinToolID) { + delete(builtinTools, id) +} + +func GetBuiltinTool(id BuiltinToolID) (ToolRegistry, bool) { + tool, ok := builtinTools[id] + return tool, ok +} + +// task internal tools + +func RegisterTaskInternalTool(t task.Task, tool ToolRegistry) (InternalToolID, error) { + if tool.CbIdInternal != InternalToolID(0) { + return InternalToolID(0), + fmt.Errorf("the registered tool must not have a callback id set") + } + if _, ok := taskInternalTools[t.ID()]; !ok { + taskInternalTools[t.ID()] = []ToolRegistry{} + } + taskInternalTools[t.ID()] = append(taskInternalTools[t.ID()], tool) + newId := InternalToolID(len(taskInternalTools[t.ID()]) - 1) + taskInternalTools[t.ID()][newId].CbIdInternal = newId + return newId, nil +} + +func ClearTaskInternalTools(t task.Task) { + delete(taskInternalTools, t.ID()) +} + +func GetTaskInternalTool(t task.Task, tid BuiltinToolID) (ToolRegistry, bool) { + taskTool, ok := taskInternalTools[t.ID()] + if !ok { + return ToolRegistry{}, false + } + if tid >= BuiltinToolID(len(taskTool)) { + return ToolRegistry{}, false + } + return taskTool[tid], true +} + +func GetTaskInternalToolByName(t task.Task, name string) (ToolRegistry, bool) { + if _, ok := taskInternalTools[t.ID()]; !ok { + return ToolRegistry{}, false + } + for tid, tool := range taskInternalTools[t.ID()] { + if tool.Name == name { + return taskInternalTools[t.ID()][tid], true + } + } + return ToolRegistry{}, false +} diff --git a/spearlet/hostcalls/embeddings.go b/spearlet/hostcalls/embeddings.go new file mode 100644 index 0000000..86c76aa --- /dev/null +++ b/spearlet/hostcalls/embeddings.go @@ -0,0 +1,39 @@ +package hostcalls + +import ( + "fmt" + + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" + "github.com/lfedgeai/spear/spearlet/hostcalls/huggingface" + openaihc "github.com/lfedgeai/spear/spearlet/hostcalls/openai" +) + +type EmbeddingFunc func(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) + +var ( + globalEmbeddings = map[string]EmbeddingFunc{ + "text-embedding-ada-002": openaihc.Embeddings, + "bge-large-en-v1.5": huggingface.Embeddings, + } +) + +func Embeddings(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { + // jsonBytes, err := json.Marshal(args) + // if 
err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + // embeddingsReq := transform.EmbeddingsRequest{} + // err = embeddingsReq.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // for k, v := range globalEmbeddings { + // if k == embeddingsReq.Model { + // return v(inv, args) + // } + // } + // return nil, fmt.Errorf("embedding not found") + + return nil, fmt.Errorf("not implemented") +} diff --git a/spearlet/hostcalls/gen_image.go b/spearlet/hostcalls/gen_image.go new file mode 100644 index 0000000..d036bc8 --- /dev/null +++ b/spearlet/hostcalls/gen_image.go @@ -0,0 +1,50 @@ +package hostcalls + +import ( + "fmt" + + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" +) + +func TextToImage(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { + // // right now we just call openai TextToSpeech + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + + // req := &transform.ImageGenerationRequest{} + // err = req.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // req2 := &oai.OpenAIImageGenerationRequest{ + // Model: req.Model, + // Prompt: req.Prompt, + // ResponseFormat: req.ResponseFormat, + // } + // ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeImageGeneration, req2.Model) + // if len(ep) == 0 { + // return nil, fmt.Errorf("error getting endpoint for model %s", req2.Model) + // } + // res, err := oai.OpenAIImageGeneration(ep[0], req2) + // if err != nil { + // return nil, fmt.Errorf("error calling openai TextToImage: %v", err) + // } + + // res2 := &transform.ImageGenerationResponse{ + // Created: res.Created, + // } + // for _, obj := range res.Data { + // res2.Data = append(res2.Data, transform.ImageObject{ + // Url: obj.Url, + // B64Json: obj.B64Json, + // RevisedPrompt: obj.RevisedPrompt, + // }) + // } + + // return res2, nil + + return nil, fmt.Errorf("not implemented") +} diff --git a/spearlet/hostcalls/hc_entries.go b/spearlet/hostcalls/hc_entries.go new file mode 100644 index 0000000..c31124a --- /dev/null +++ b/spearlet/hostcalls/hc_entries.go @@ -0,0 +1,100 @@ +package hostcalls + +import ( + "github.com/lfedgeai/spear/pkg/spear/proto/transport" + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" +) + +var Hostcalls = []*hostcalls.HostCall{ + { + NameID: transport.MethodTransform, + Handler: Transform, + }, + { + NameID: transport.MethodTransformConfig, + Handler: TransformConfig, + }, + // invoke tool + { + NameID: transport.MethodToolInvoke, + Handler: nil, + }, + { + NameID: transport.MethodInternalToolCreate, + Handler: NewInternalTool, + }, + // // chat operations + // { + // NameID: transform.HostCallChatCompletion, + // Handler: ChatCompletionWithTools, + // }, + // // text to speech operations + // { + // NameID: openai.HostCallTextToSpeech, + // Handler: openaihc.TextToSpeech, + // }, + // // image generation operations + // { + // NameID: openai.HostCallImageGeneration, + // Handler: openaihc.ImageGeneration, + // }, + // // embeddings operations + // { + // NameID: openai.HostCallEmbeddings, + // Handler: openaihc.Embeddings, + // }, + // vector store operations + { + NameID: transport.MethodVecStoreCreate, + Handler: VectorStoreCreate, + }, + { + NameID: transport.MethodVecStoreDelete, + Handler: VectorStoreDelete, + }, + { + NameID: transport.MethodVecStoreInsert, 
+ Handler: VectorStoreInsert, + }, + { + NameID: transport.MethodVecStoreQuery, + Handler: VectorStoreSearch, + }, + // message passing operations + // { + // NameID: payload.HostCallMessagePassingRegister, + // Handler: MessagePassingRegister, + // }, + // { + // NameID: payload.HostCallMessagePassingUnregister, + // Handler: MessagePassingUnregister, + // }, + // { + // NameID: payload.HostCallMessagePassingLookup, + // Handler: MessagePassingLookup, + // }, + // { + // NameID: payload.HostCallMessagePassingSend, + // Handler: MessagePassingSend, + // }, + // input operations + { + NameID: transport.MethodInput, + Handler: Input, + }, + // speak operations + { + NameID: transport.MethodSpeak, + Handler: Speak, + }, + // record operations + { + NameID: transport.MethodRecord, + Handler: Record, + }, + // custom operations + { + NameID: transport.MethodCustom, + Handler: nil, + }, +} diff --git a/spearlet/hostcalls/huggingface/huggingface_hc.go b/spearlet/hostcalls/huggingface/huggingface_hc.go new file mode 100644 index 0000000..e00ca38 --- /dev/null +++ b/spearlet/hostcalls/huggingface/huggingface_hc.go @@ -0,0 +1,84 @@ +package huggingface + +import ( + "fmt" + + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" +) + +type HuggingFaceEmbeddingsRequest struct { + Inputs string `json:"inputs"` +} + +func Embeddings(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { + // // verify the type of args is EmbeddingsRequest + // // use json marshal and unmarshal to verify the type + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + // embeddingsReq := transform.EmbeddingsRequest{} + // err = embeddingsReq.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // embeddingsReq2 := HuggingFaceEmbeddingsRequest{ + // Inputs: embeddingsReq.Input, + // } + + // jsonBytes, err = json.Marshal(embeddingsReq2) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + + // // make sure HUGGINGFACEHUB_API_TOKEN is there + // if os.Getenv("HUGGINGFACEHUB_API_TOKEN") == "" { + // return nil, fmt.Errorf("error getting huggingface api token") + // } + // apiKey := os.Getenv("HUGGINGFACEHUB_API_TOKEN") + + // log.Debugf("Embeddings Request: %s", string(jsonBytes)) + // res, err := net.SendRequest( + // "https://api-inference.huggingface.co/models/BAAI/bge-large-en-v1.5", + // bytes.NewBuffer(jsonBytes), + // net.ContentTypeJSON, + // apiKey, + // ) + + // if err != nil { + // return nil, fmt.Errorf("error sending request: %v", err) + // } + + // listRes := []float64{} + // if err := json.Unmarshal(res, &listRes); err != nil { + // // might be something like + // // {"error":"Model BAAI/bge-large-en-v1.5 is currently loading","estimated_time":53.62286376953125} + // tmp := map[string]interface{}{} + // if err := json.Unmarshal(res, &tmp); err != nil { + // log.Errorf("Error unmarshalling data: %v", res) + // return nil, fmt.Errorf("error unmarshalling data. %v", err) + // } + // if _, ok := tmp["error"]; ok { + // log.Warnf("Model is not ready yet: %v", tmp) + // listRes = []float64{1.1, 2.2, 3.3} + // } else { + // log.Errorf("Error unmarshalling data: %v", res) + // return nil, fmt.Errorf("error unmarshalling data. 
%v", err) + // } + // } + // respData := transform.EmbeddingsResponse{} + // respData.Data = []transform.EmbeddingObject{ + // { + // Object: "embedding", + // Embedding: listRes, + // Index: 0, + // }, + // } + // respData.Model = "bge-large-en-v1.5" + + // // return the response + // return respData, nil + + return nil, fmt.Errorf("not implemented") +} diff --git a/worker/hostcalls/local_utils.go b/spearlet/hostcalls/io.go similarity index 53% rename from worker/hostcalls/local_utils.go rename to spearlet/hostcalls/io.go index 9b41d04..4b7305d 100644 --- a/worker/hostcalls/local_utils.go +++ b/spearlet/hostcalls/io.go @@ -8,48 +8,90 @@ import ( "sync" "time" + flatbuffers "github.com/google/flatbuffers/go" + "github.com/lfedgeai/spear/pkg/io" + protoio "github.com/lfedgeai/spear/pkg/spear/proto/io" + "github.com/faiface/beep" "github.com/faiface/beep/mp3" "github.com/faiface/beep/speaker" - "github.com/lfedgeai/spear/pkg/io" - "github.com/lfedgeai/spear/pkg/utils" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" "github.com/schollz/progressbar/v3" + log "github.com/sirupsen/logrus" ) -func Input(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - // read from stdin - fmt.Print(args.(string)) +const ( + defaultTTSModel = "tts-1" + defaultTTSVoice = "nova" + defaultTTSFormat = "mp3" + defaultSTTModel = "whisper-1" +) + +func Input(inv *hostcalls.InvocationInfo, args []byte) ([]byte, error) { + req := protoio.GetRootAsInputRequest(args, 0) + if req == nil { + return nil, fmt.Errorf("could not get InputRequest") + } + + // display the prompt + fmt.Print(req.Prompt()) reader := bufio.NewReader(os.Stdout) // Read a line from stdout - line, err := reader.ReadString('\n') - if err != nil { - fmt.Println("Error reading from stdout:", err) - return nil, err + line := "" + var err error + if req.Dryrun() { + line = "test" + } else { + line, err = reader.ReadString('\n') + if err != nil { + fmt.Println("Error reading from stdout:", err) + return nil, err + } } - return line, nil + builder := flatbuffers.NewBuilder(0) + lineOff := builder.CreateString(line) + + protoio.InputResponseStart(builder) + protoio.InputResponseAddText(builder, lineOff) + builder.Finish(protoio.InputResponseEnd(builder)) + + return builder.FinishedBytes(), nil } -func Speak(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { +func Speak(inv *hostcalls.InvocationInfo, args []byte) ([]byte, error) { // speak the text - // umarshal the response + req := protoio.GetRootAsSpeakRequest(args, 0) + if req == nil { + return nil, fmt.Errorf("could not get SpeakRequest") + } - var data map[string]interface{} - err := utils.InterfaceToType(&data, args) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) + transcript := req.Text() + model := defaultTTSModel + voice := defaultTTSVoice + format := defaultTTSFormat + + if len(req.Model()) > 0 { + model = string(req.Model()) + } + if len(req.Voice()) > 0 { + voice = string(req.Voice()) + } + if len(req.Format()) > 0 { + format = string(req.Format()) } - // get the "audio" key from the response - encodedData, ok := data["audio"] - if !ok { - panic("audio key not found in response") + encodedData, err := textToSpeechData(string(transcript), model, voice, format) + if err != nil { + return nil, fmt.Errorf("error getting audio data: %w", err) } + + log.Debugf("Speaking: %s", transcript) + // convert from base64 to []byte - rawData, err := 
base64.StdEncoding.DecodeString(encodedData) + if err != nil { + panic("base64.StdEncoding.DecodeString failed: " + err.Error()) + } @@ -59,7 +101,7 @@ func Speak(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) // write to a temp file f, err := os.CreateTemp("", "audio*.mp3") if err != nil { - return nil, fmt.Errorf("os.CreateTemp failed: " + err.Error()) + return nil, fmt.Errorf("os.CreateTemp failed: %s", err.Error()) } defer os.Remove(f.Name()) @@ -67,7 +109,7 @@ func Speak(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) // wrtie the audio data to the file _, err = f.Write(rawData) if err != nil { - return nil, fmt.Errorf("f.Write failed: " + err.Error()) + return nil, fmt.Errorf("f.Write failed: %s", err.Error()) } f.Close() log.Debugf("Created temp file: %s", f.Name()) @@ -77,7 +119,14 @@ func Speak(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) return nil, fmt.Errorf("could not play MP3 file: %w", err) } - return nil, nil + builder := flatbuffers.NewBuilder(0) + dataOff := builder.CreateString(encodedData) + + protoio.SpeakResponseStart(builder) + protoio.SpeakResponseAddData(builder, dataOff) + builder.Finish(protoio.SpeakResponseEnd(builder)) + + return builder.FinishedBytes(), nil } func playMP3(filePath string) error { @@ -136,9 +185,16 @@ func playMP3(filePath string) error { } } -func Record(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - // record audio - fmt.Println(args.(string)) +func Record(inv *hostcalls.InvocationInfo, args []byte) ([]byte, error) { + req := protoio.GetRootAsRecordRequest(args, 0) + if req == nil { + return nil, fmt.Errorf("could not get RecordRequest") + } + + model := defaultSTTModel + if req.Model() != nil { + model = string(req.Model()) + } wg := &sync.WaitGroup{} wg.Add(1) @@ -149,6 +205,7 @@ func Record(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error log.Errorf("Failed to record audio: %v", err) return } + wavData = data }) if err != nil { @@ -181,5 +238,18 @@ func Record(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error bar.Describe("Recorded") bar.Close() - return wavData, nil + // convert wavData to text + text, err := speechToTextString(wavData, model) + if err != nil { + return nil, fmt.Errorf("error converting audio to text: %w", err) + } + + builder := flatbuffers.NewBuilder(0) + textOff := builder.CreateString(text) + + protoio.RecordResponseStart(builder) + protoio.RecordResponseAddText(builder, textOff) + builder.Finish(protoio.RecordResponseEnd(builder)) + + return builder.FinishedBytes(), nil } diff --git a/spearlet/hostcalls/msgpassing.go b/spearlet/hostcalls/msgpassing.go new file mode 100644 index 0000000..6eb10ac --- /dev/null +++ b/spearlet/hostcalls/msgpassing.go @@ -0,0 +1,143 @@ +package hostcalls + +import ( + "fmt" + + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" +) + +type MessagePassingRegistry struct { + name string + method string + pendingData chan interface{} + id uint64 +} + +var ( + globalRegisteredMessagePassing = map[string]MessagePassingRegistry{} +) + +func MessagePassingRegister(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { + // task := *(inv.Task) + // log.Debugf("Executing hostcall \"%s\" with args %v for task %s", + // payload.HostCallMessagePassingRegister, args, task.ID()) + + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil,
fmt.Errorf("error marshalling args: %v", err) + // } + // req := payload.MessagePassingRegisterRequest{} + // err = req.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // globalRegisteredMessagePassing[req.Name] = MessagePassingRegistry{ + // name: req.Name, + // method: req.Method, + // pendingData: make(chan interface{}), + // id: rand.Uint64(), + // } + + // return &payload.MessagePassingRegisterResponse{ + // MsgPassingId: globalRegisteredMessagePassing[req.Name].id, + // }, nil + + return nil, fmt.Errorf("not implemented") +} + +func MessagePassingUnregister(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { + // task := *(inv.Task) + // log.Debugf("Executing hostcall \"%s\" with args %v for task %s", + // payload.HostCallMessagePassingUnregister, args, task.ID()) + + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + // req := payload.MessagePassingUnregisterRequest{} + // err = req.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // found := false + + // for k, v := range globalRegisteredMessagePassing { + // if v.id == req.MsgPassingId { + // delete(globalRegisteredMessagePassing, k) + // found = true + // break + // } + // } + + // if !found { + // return nil, fmt.Errorf("message passing id not found") + // } + + // return &payload.MessagePassingUnregisterResponse{ + // MsgPassingId: req.MsgPassingId, + // }, nil + + return nil, fmt.Errorf("not implemented") +} + +func MessagePassingLookup(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { + // task := *(inv.Task) + // log.Debugf("Executing hostcall \"%s\" with args %v for task %s", + // payload.HostCallMessagePassingLookup, args, task.ID()) + + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + // req := payload.MessagePassingLookupRequest{} + // err = req.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // if v, ok := globalRegisteredMessagePassing[req.Name]; ok { + // return &payload.MessagePassingLookupResponse{ + // MsgPassingId: v.id, + // }, nil + // } + + // return nil, fmt.Errorf("message passing name not found") + + return nil, fmt.Errorf("not implemented") +} + +func MessagePassingSend(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { + // task := *(inv.Task) + // log.Debugf("Executing hostcall \"%s\" with args %v for task %s", + // payload.HostCallMessagePassingSend, args, task.ID()) + + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + // req := payload.MessagePassingSendRequest{} + // err = req.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // for _, v := range globalRegisteredMessagePassing { + // if v.id == req.MsgPassingId { + // // send without blocking + // select { + // case v.pendingData <- req.Data: + // return &payload.MessagePassingSendResponse{ + // MsgPassingId: req.MsgPassingId, + // }, nil + // default: + // return nil, fmt.Errorf("message passing channel full") + // } + // } + // } + + // return nil, fmt.Errorf("message passing id not found") + + return nil, fmt.Errorf("not implemented") +} diff --git 
a/worker/hostcalls/openai/openai_hc.go b/spearlet/hostcalls/openai/openai_hc.go similarity index 82% rename from worker/hostcalls/openai/openai_hc.go rename to spearlet/hostcalls/openai/openai_hc.go index d7a84e2..575250f 100644 --- a/worker/hostcalls/openai/openai_hc.go +++ b/spearlet/hostcalls/openai/openai_hc.go @@ -9,10 +9,8 @@ import ( "mime/multipart" "github.com/lfedgeai/spear/pkg/net" - "github.com/lfedgeai/spear/pkg/rpc/payload/transform" - "github.com/lfedgeai/spear/pkg/utils" - "github.com/lfedgeai/spear/worker/hostcalls/common" - hcommon "github.com/lfedgeai/spear/worker/hostcalls/common" + "github.com/lfedgeai/spear/spearlet/hostcalls/common" + hcommon "github.com/lfedgeai/spear/spearlet/hostcalls/common" log "github.com/sirupsen/logrus" ) @@ -42,7 +40,7 @@ type OpenAIChatCompletionResponse struct { type OpenAIChatChoice struct { Message OpenAIChatMessage `json:"message"` - Index json.Number `json:"index"` + Index int64 `json:"index"` Reason string `json:"finish_reason"` } @@ -70,9 +68,10 @@ type OpenAIChatToolFunctionSub struct { } type OpenAIChatCompletionRequest struct { - Messages []OpenAIChatMessage `json:"messages"` - Model string `json:"model"` - Tools []OpenAIChatToolFunction `json:"tools"` + Messages []OpenAIChatMessage `json:"messages"` + Model string `json:"model"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + Tools []OpenAIChatToolFunction `json:"tools"` } type EndpointInfo struct { @@ -86,15 +85,16 @@ func OpenAIChatCompletion(ep common.APIEndpointInfo, chatReq *OpenAIChatCompleti return nil, fmt.Errorf("error marshalling OpenAIChatCompletionRequest: %v", err) } - // log.Debugf("Chat Request: %s", string(jsonBytes)) // create a https request to https:///chat/completions and use b as the request body - u := ep.Base + ep.Url + u := *ep.Base + ep.Url + log.Infof("URL: %s", u) + log.Infof("Request: %s", string(jsonBytes)) res, err := net.SendRequest(u, bytes.NewBuffer(jsonBytes), net.ContentTypeJSON, ep.APIKey) if err != nil { return nil, fmt.Errorf("error sending request: %v", err) } - // log.Debugf("OpenAI Response: %s", string(res)) + log.Infof("Response: %s", string(res)) respData := OpenAIChatCompletionResponse{} err = json.Unmarshal(res, &respData) if err != nil { @@ -136,43 +136,45 @@ type OpenAIEmbeddingsResponse struct { } func Embeddings(inv *hcommon.InvocationInfo, args interface{}) (interface{}, error) { - // verify the type of args is EmbeddingsRequest - // use json marshal and unmarshal to verify the type - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - embeddingsReq := transform.EmbeddingsRequest{} - err = embeddingsReq.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - req := OpenAIEmbeddingsRequest{ - Input: embeddingsReq.Input, - Model: embeddingsReq.Model, - } - - ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeEmbeddings, req.Model) - if len(ep) == 0 { - return nil, fmt.Errorf("error getting endpoint for model %s", req.Model) - } - resp, err := OpenAIEmbeddings(ep[0], &req) - if err != nil { - return nil, fmt.Errorf("error calling OpenAIEmbeddings: %v", err) - } - - resp2 := transform.EmbeddingsResponse{ - Object: resp.Object, - Model: resp.Model, - Usage: resp.Usage, - } - err = utils.InterfaceToType(&resp2.Data, resp.Data) - if err != nil { - return nil, fmt.Errorf("error converting response: %v", err) - } - - return resp2, nil + // // verify the type of args is EmbeddingsRequest + // // 
use json marshal and unmarshal to verify the type + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + // embeddingsReq := transform.EmbeddingsRequest{} + // err = embeddingsReq.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // req := OpenAIEmbeddingsRequest{ + // Input: embeddingsReq.Input, + // Model: embeddingsReq.Model, + // } + + // ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeEmbeddings, req.Model) + // if len(ep) == 0 { + // return nil, fmt.Errorf("error getting endpoint for model %s", req.Model) + // } + // resp, err := OpenAIEmbeddings(ep[0], &req) + // if err != nil { + // return nil, fmt.Errorf("error calling OpenAIEmbeddings: %v", err) + // } + + // resp2 := transform.EmbeddingsResponse{ + // Object: resp.Object, + // Model: resp.Model, + // Usage: resp.Usage, + // } + // err = utils.InterfaceToType(&resp2.Data, resp.Data) + // if err != nil { + // return nil, fmt.Errorf("error converting response: %v", err) + // } + + // return resp2, nil + + return nil, fmt.Errorf("not implemented") } func OpenAIEmbeddings(ep common.APIEndpointInfo, args *OpenAIEmbeddingsRequest) (*OpenAIEmbeddingsResponse, error) { @@ -185,7 +187,7 @@ func OpenAIEmbeddings(ep common.APIEndpointInfo, args *OpenAIEmbeddingsRequest) log.Debugf("Embeddings Request: %s", string(jsonBytes)) // create a https request to https://api.openai.com/v1/embeddings and use b as the request body - u := ep.Base + ep.Url + u := *ep.Base + ep.Url res, err := net.SendRequest(u, bytes.NewBuffer(jsonBytes), net.ContentTypeJSON, ep.APIKey) if err != nil { return nil, fmt.Errorf("error sending request: %v", err) @@ -222,7 +224,7 @@ func OpenAITextToSpeech(ep common.APIEndpointInfo, args *OpenAITextToSpeechReque } log.Debugf("TextToSpeech Request: %s", string(jsonBytes)) - u := ep.Base + ep.Url + u := *ep.Base + ep.Url res, err := net.SendRequest(u, bytes.NewBuffer(jsonBytes), net.ContentTypeJSON, ep.APIKey) if err != nil { return nil, fmt.Errorf("error sending request: %v", err) @@ -243,7 +245,7 @@ func OpenAITextToSpeech(ep common.APIEndpointInfo, args *OpenAITextToSpeechReque type OpenAISpeechToTextRequest struct { Model string `json:"model"` - Audio string `json:"audio"` + Audio []byte `json:"audio"` } type OpenAISpeechToTextResponse struct { @@ -265,7 +267,7 @@ func OpenAISpeechToText(ep common.APIEndpointInfo, args *OpenAISpeechToTextReque } log.Debugf("SpeechToText Request: %v", sttReq) - u := ep.Base + ep.Url + u := *ep.Base + ep.Url // send data as multipart/form-data payload := &bytes.Buffer{} @@ -276,12 +278,7 @@ func OpenAISpeechToText(ep common.APIEndpointInfo, args *OpenAISpeechToTextReque } log.Debugf("Audio data: %v", sttReq.Audio) // convert base64 encoded audio data to bytes - data := make([]byte, base64.StdEncoding.DecodedLen(len(sttReq.Audio))) - n, err := base64.StdEncoding.Decode(data, []byte(sttReq.Audio)) - if err != nil { - return nil, fmt.Errorf("error decoding audio data: %v", err) - } - log.Debugf("Decoded audio data len: %d", n) + data := sttReq.Audio _, err = io.Copy(part, bytes.NewReader(data)) if err != nil { return nil, fmt.Errorf("error copying audio data: %v", err) @@ -345,7 +342,7 @@ func OpenAIImageGeneration(ep common.APIEndpointInfo, args *OpenAIImageGeneratio } log.Debugf("ImageGeneration Request: %s", string(jsonBytes)) - u := ep.Base + ep.Url + u := *ep.Base + ep.Url res, err := net.SendRequest(u, bytes.NewBuffer(jsonBytes), 
net.ContentTypeJSON, ep.APIKey) if err != nil { return nil, fmt.Errorf("error sending request: %v", err) diff --git a/spearlet/hostcalls/stt.go b/spearlet/hostcalls/stt.go new file mode 100644 index 0000000..1a3dd16 --- /dev/null +++ b/spearlet/hostcalls/stt.go @@ -0,0 +1,57 @@ +package hostcalls + +import ( + "fmt" + + "github.com/lfedgeai/spear/spearlet/hostcalls/common" + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" + oai "github.com/lfedgeai/spear/spearlet/hostcalls/openai" +) + +func SpeechToText(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { + // // right now we just call openai SpeechToText + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + + // req := &transform.SpeechToTextRequest{} + // err = req.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // req2 := &oai.OpenAISpeechToTextRequest{ + // Model: req.Model, + // Audio: req.Audio, + // } + // ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeSpeechToText, req2.Model) + // if len(ep) == 0 { + // return nil, fmt.Errorf("error getting endpoint for model %s", req2.Model) + // } + // res, err := oai.OpenAISpeechToText(ep[0], req2) + // if err != nil { + // return nil, fmt.Errorf("error calling openai SpeechToText: %v", err) + // } + + // return res, nil + + return nil, fmt.Errorf("not implemented") +} + +func speechToTextString(audio []byte, model string) (string, error) { + req2 := &oai.OpenAISpeechToTextRequest{ + Model: model, + Audio: audio, + } + ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeSpeechToText, req2.Model) + if len(ep) == 0 { + return "", fmt.Errorf("error getting endpoint for model %s", req2.Model) + } + res, err := oai.OpenAISpeechToText(ep[0], req2) + if err != nil { + return "", fmt.Errorf("error calling openai SpeechToText: %v", err) + } + + return res.Text, nil +} diff --git a/spearlet/hostcalls/tools.go b/spearlet/hostcalls/tools.go new file mode 100644 index 0000000..b81ced9 --- /dev/null +++ b/spearlet/hostcalls/tools.go @@ -0,0 +1,51 @@ +package hostcalls + +import ( + "fmt" + + flatbuffers "github.com/google/flatbuffers/go" + "github.com/lfedgeai/spear/pkg/spear/proto/tool" + hcommon "github.com/lfedgeai/spear/spearlet/hostcalls/common" + log "github.com/sirupsen/logrus" +) + +func NewInternalTool(inv *hcommon.InvocationInfo, + args []byte) ([]byte, error) { + req := tool.GetRootAsInternalToolCreateRequest(args, 0) + if req == nil { + return nil, fmt.Errorf("could not get InternalToolCreateRequest") + } + + toolreg := hcommon.ToolRegistry{ + Name: string(req.Name()), + Description: string(req.Description()), + Params: make(map[string]hcommon.ToolParam), + } + + for i := 0; i < req.ParamsLength(); i++ { + paramSpec := tool.InternalToolCreateParamSpec{} + if !req.Params(&paramSpec, i) { + return nil, fmt.Errorf("could not get param spec") + } + + toolreg.Params[string(paramSpec.Name())] = hcommon.ToolParam{ + Ptype: string(paramSpec.Type()), + Description: string(paramSpec.Description()), + Required: paramSpec.Required(), + } + } + + log.Infof("Registering internal tool %+v", toolreg) + + newId, err := hcommon.RegisterTaskInternalTool(inv.Task, toolreg) + if err != nil { + return nil, err + } + + builder := flatbuffers.NewBuilder(0) + tool.InternalToolCreateResponseStart(builder) + tool.InternalToolCreateResponseAddToolId(builder, int64(newId)) +
builder.Finish(tool.InternalToolCreateResponseEnd(builder)) + + return builder.FinishedBytes(), nil +} diff --git a/spearlet/hostcalls/transform.go b/spearlet/hostcalls/transform.go new file mode 100644 index 0000000..b56eb57 --- /dev/null +++ b/spearlet/hostcalls/transform.go @@ -0,0 +1,168 @@ +package hostcalls + +import ( + "fmt" + + "github.com/lfedgeai/spear/pkg/spear/proto/transform" + "github.com/lfedgeai/spear/pkg/spear/proto/transport" + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" + log "github.com/sirupsen/logrus" +) + +type TransformRegistry struct { + name string + inputTypes []transform.TransformType + outputTypes []transform.TransformType + operations []transform.TransformOperation + cb func(*hostcalls.InvocationInfo, *transform.TransformRequest) ([]byte, error) +} + +var ( + globalRegisteredTransform = []TransformRegistry{ + { + name: "chat_with_tools", + inputTypes: []transform.TransformType{transform.TransformTypeText}, + outputTypes: []transform.TransformType{transform.TransformTypeText}, + operations: []transform.TransformOperation{ + transform.TransformOperationLLM, + transform.TransformOperationTools, + }, + cb: ChatCompletionWithTools, + }, + { + name: "chat", + inputTypes: []transform.TransformType{transform.TransformTypeText}, + outputTypes: []transform.TransformType{transform.TransformTypeText}, + operations: []transform.TransformOperation{transform.TransformOperationLLM}, + cb: ChatCompletionNoTools, + }, + // { + // name: "embeddings", + // inputTypes: []payload.TransformType{payload.TransformTypeText}, + // outputTypes: []payload.TransformType{payload.TransformTypeVector}, + // operations: []payload.TransformOperation{payload.TransformOperationEmbeddings}, + // cb: Embeddings, + // }, + // { + // name: "text-to-speech", + // inputTypes: []payload.TransformType{payload.TransformTypeText}, + // outputTypes: []payload.TransformType{payload.TransformTypeAudio}, + // operations: []payload.TransformOperation{payload.TransformOperationTextToSpeech}, + // cb: TextToSpeech, + // }, + // { + // name: "speech-to-text", + // inputTypes: []payload.TransformType{payload.TransformTypeAudio}, + // outputTypes: []payload.TransformType{payload.TransformTypeText}, + // operations: []payload.TransformOperation{payload.TransformOperationSpeechToText}, + // cb: SpeechToText, + // }, + // { + // name: "text-to-image", + // inputTypes: []payload.TransformType{payload.TransformTypeText}, + // outputTypes: []payload.TransformType{payload.TransformTypeImage}, + // operations: []payload.TransformOperation{payload.TransformOperationTextToImage}, + // cb: TextToImage, + // }, + } +) + +func isSubSetOf[T comparable](a, b []T) bool { + for _, x := range a { + found := false + for _, y := range b { + if x == y { + found = true + break + } + } + if !found { + return false + } + } + return true +} + +func TransformConfig(inv *hostcalls.InvocationInfo, args []byte) ([]byte, error) { + // task := *(inv.Task) + // log.Debugf("Executing hostcall \"%s\" with args %v for task %s", + // payload.HostCallTransformConfig, args, task.ID()) + // // convert args to TransformConfigRequest + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + + // req := &payload.TransformConfigRequest{} + // err = req.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // if req.Reset { + // task.SetVar(t.TVTest, nil) + // return &payload.TransformConfigResponse{ 
+ // Result: "success", + // }, nil + // } + + // if req.Test != "" { + // task.SetVar(t.TVTest, req.Test) + // } + + // return &payload.TransformConfigResponse{ + // Result: "success", + // }, nil + return nil, fmt.Errorf("hostcall not implemented") +} + +func Transform(inv *hostcalls.InvocationInfo, args []byte) ([]byte, error) { + req := transform.GetRootAsTransformRequest(args, 0) + if req == nil { + return nil, fmt.Errorf("could not get TransformRequest") + } + + var candid *TransformRegistry + + inputTypes := make([]transform.TransformType, req.InputTypesLength()) + for i := 0; i < req.InputTypesLength(); i++ { + inputTypes[i] = req.InputTypes(i) + } + outputTypes := make([]transform.TransformType, req.OutputTypesLength()) + for i := 0; i < req.OutputTypesLength(); i++ { + outputTypes[i] = req.OutputTypes(i) + } + operations := make([]transform.TransformOperation, req.OperationsLength()) + for i := 0; i < req.OperationsLength(); i++ { + operations[i] = req.Operations(i) + } + // find the transform registry + for _, reg := range globalRegisteredTransform { + if isSubSetOf(inputTypes, reg.inputTypes) && + isSubSetOf(outputTypes, reg.outputTypes) && + isSubSetOf(operations, reg.operations) { + if candid != nil { + if len(reg.inputTypes) <= len(candid.inputTypes) && + len(reg.outputTypes) <= len(candid.outputTypes) && + len(reg.operations) <= len(candid.operations) { + candid = &reg + } + } else { + candid = &reg + } + } + } + + if candid != nil { + log.Infof("Using transform registry %s", candid.name) + res, err := candid.cb(inv, req) + if err != nil { + return nil, fmt.Errorf("error calling %s: %v", candid.name, err) + } + + log.Debugf("Transform result: %+v", res) + return res, nil + } + + return nil, fmt.Errorf("hostcall \"%v\" not implemented", transport.MethodTransform) +} diff --git a/spearlet/hostcalls/tts.go b/spearlet/hostcalls/tts.go new file mode 100644 index 0000000..55d0f10 --- /dev/null +++ b/spearlet/hostcalls/tts.go @@ -0,0 +1,56 @@ +package hostcalls + +import ( + "fmt" + + "github.com/lfedgeai/spear/spearlet/hostcalls/common" + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" + oai "github.com/lfedgeai/spear/spearlet/hostcalls/openai" +) + +func TextToSpeech(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { + // // right now we just call openai TextToSpeech + // req := &transform.TextToSpeechRequest{} + // err := utils.InterfaceToType(&req, args) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // req2 := &oai.OpenAITextToSpeechRequest{ + // Model: req.Model, + // Input: req.Input, + // Voice: req.Voice, + // Format: req.Format, + // } + // ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeTextToSpeech, req2.Model) + // if len(ep) == 0 { + // return nil, fmt.Errorf("error getting endpoint for model %s", req2.Model) + // } + // res, err := oai.OpenAITextToSpeech(ep[0], req2) + // if err != nil { + // return nil, fmt.Errorf("error calling openai TextToSpeech: %v", err) + // } + + // return res, nil + + return nil, fmt.Errorf("not implemented") +} + +func textToSpeechData(text, model, voice, format string) (string, error) { + req2 := &oai.OpenAITextToSpeechRequest{ + Model: model, + Input: text, + Voice: voice, + Format: format, + } + ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeTextToSpeech, req2.Model) + if len(ep) == 0 { + return "", fmt.Errorf("error getting endpoint for model %s", req2.Model) + } + res, err := oai.OpenAITextToSpeech(ep[0], req2) + if err != nil { +
return "", fmt.Errorf("error calling openai TextToSpeech: %v", err) + } + + return res.EncodedAudio, nil +} diff --git a/spearlet/hostcalls/vectorstore.go b/spearlet/hostcalls/vectorstore.go new file mode 100644 index 0000000..e7c29b7 --- /dev/null +++ b/spearlet/hostcalls/vectorstore.go @@ -0,0 +1,299 @@ +package hostcalls + +import ( + "context" + "fmt" + + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" + "github.com/lfedgeai/spear/spearlet/task" + "github.com/qdrant/go-client/qdrant" + log "github.com/sirupsen/logrus" +) + +var ( + globalVectorStoreRegistries = make(map[task.TaskID]*VectorStoreRegistry) +) + +type VectorStore struct { + Name string + NextID uint64 +} + +type VectorStoreRegistry struct { + Stores []*VectorStore + Client *qdrant.Client +} + +type VectorStoreSearchResult struct { + Vector []float32 + Data []byte +} + +func NewVectorStoreRegistry() (*VectorStoreRegistry, error) { + qdrantClient, err := qdrant.NewClient(&qdrant.Config{ + Host: "localhost", + Port: 6334, + }) + if err != nil { + log.Errorf("Error creating qdrant client: %v", err) + return nil, err + } + // list all collections + collections, err := qdrantClient.ListCollections(context.Background()) + if err != nil { + log.Errorf("Error listing collections: %v", err) + return nil, err + } + log.Infof("Collections: %v", collections) + return &VectorStoreRegistry{ + Stores: make([]*VectorStore, 0), + Client: qdrantClient, + }, nil +} + +func (r *VectorStoreRegistry) Create(storeName string, dimensions uint64) (int, error) { + log.Infof("Creating vector store with name %s", storeName) + // duplicated store is not allowed + for i, store := range r.Stores { + if store.Name == storeName { + return i, fmt.Errorf("store with name %s already exists", storeName) + } + } + + // create the vector store in qdrant + err := r.Client.CreateCollection(context.Background(), &qdrant.CreateCollection{ + CollectionName: storeName, + VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ + Size: dimensions, + Distance: qdrant.Distance_Cosine, + }), + }) + if err != nil { + return -1, fmt.Errorf("error creating collection: %v", err) + } + + // create a new vector store with the given name + r.Stores = append(r.Stores, &VectorStore{ + Name: storeName, + NextID: 1, + }) + + return len(r.Stores) - 1, nil +} + +func (r *VectorStoreRegistry) Delete(vid int) error { + log.Infof("Deleting vector store with id %d", vid) + // delete the vector store in qdrant + err := r.Client.DeleteCollection(context.Background(), r.Stores[vid].Name) + if err != nil { + return fmt.Errorf("error deleting collection: %v", err) + } + + // remove the vid-th vector store + r.Stores = append(r.Stores[:vid], r.Stores[vid+1:]...) 
+ + return nil +} + +func (r *VectorStoreRegistry) Insert(vid int, vector []float32, payload []byte) error { + log.Infof("Inserting vector into vector store with id %d", vid) + // insert the vector into qdrant + opInfo, err := r.Client.Upsert(context.Background(), &qdrant.UpsertPoints{ + CollectionName: r.Stores[vid].Name, + Points: []*qdrant.PointStruct{ + { + Id: qdrant.NewIDNum(r.Stores[vid].NextID), + Payload: qdrant.NewValueMap(map[string]interface{}{ + "payload": payload, + }), + Vectors: qdrant.NewVectors(vector...), + }, + }, + }) + if err != nil { + return fmt.Errorf("error upserting points: %v", err) + } + r.Stores[vid].NextID = r.Stores[vid].NextID + 1 + log.Infof("Upsert operation info: %v", opInfo) + return nil +} + +func (r *VectorStoreRegistry) Search(vid int, vector []float32, limit uint64) ([]*VectorStoreSearchResult, error) { + log.Infof("Searching vector in vector store with vid %d and vector %v", vid, vector) + // search the vector in qdrant + result, err := r.Client.Query(context.Background(), &qdrant.QueryPoints{ + CollectionName: r.Stores[vid].Name, + Query: qdrant.NewQuery(vector...), + Limit: &limit, + }) + if err != nil { + return nil, fmt.Errorf("error querying points: %v", err) + } + ret := make([]*VectorStoreSearchResult, len(result)) + for i, res := range result { + if res.Vectors == nil { + log.Infof(fmt.Sprintf("Vector is nil: %v", res)) + ret[i] = &VectorStoreSearchResult{ + Vector: nil, + Data: []byte(res.Payload["payload"].String()), + } + } else { + ret[i] = &VectorStoreSearchResult{ + Vector: res.Vectors.GetVector().Data, + Data: []byte(res.Payload["payload"].String()), + } + } + } + log.Infof("Search result: %+v", ret) + return ret, nil +} + +func VectorStoreCreate(inv *hostcalls.InvocationInfo, args []byte) ([]byte, error) { + // task := *(inv.Task) + // log.Debugf("Executing hostcall \"%s\" with args %v", payload.HostCallVectorStoreCreate, args) + // // verify the type of args is string + // // use json marshal and unmarshal to verify the type + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + // req := payload.VectorStoreCreateRequest{} + // err = req.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // log.Infof("VectorStoreCreate Request: %v", req) + // // create a new vector store + // if _, ok := globalVectorStoreRegistries[task.ID()]; !ok { + // val, err := NewVectorStoreRegistry() + // if err != nil { + // return nil, fmt.Errorf("error creating vector store registry: %v", err) + // } + // globalVectorStoreRegistries[task.ID()] = val + // } + + // vid, err := globalVectorStoreRegistries[task.ID()].Create(req.Name, req.Dimentions) + // if err != nil { + // return nil, fmt.Errorf("error creating vector store: %v", err) + // } + + // // return the response + // return &payload.VectorStoreCreateResponse{ + // VID: vid, + // }, nil + + return nil, fmt.Errorf("not implemented") +} + +func VectorStoreDelete(inv *hostcalls.InvocationInfo, args []byte) ([]byte, error) { + // task := *(inv.Task) + // log.Debugf("Executing hostcall \"%s\" with args %v", payload.HostCallVectorStoreDelete, args) + // // verify the type of args is int + // // use json marshal and unmarshal to verify the type + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + // req := payload.VectorStoreDeleteRequest{} + // err = req.Unmarshal(jsonBytes) + // 
if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // log.Infof("VectorStoreDelete Request: %v", req) + // // delete the vector store + // if _, ok := globalVectorStoreRegistries[task.ID()]; !ok { + // return nil, fmt.Errorf("vector store registry not found") + // } + + // err = globalVectorStoreRegistries[task.ID()].Delete(req.VID) + // if err != nil { + // return nil, fmt.Errorf("error deleting vector store: %v", err) + // } + + // // return the response + // return &payload.VectorStoreDeleteResponse{ + // VID: req.VID, + // }, nil + + return nil, fmt.Errorf("not implemented") +} + +func VectorStoreInsert(inv *hostcalls.InvocationInfo, args []byte) ([]byte, error) { + // task := *(inv.Task) + // log.Debugf("Executing hostcall \"%s\" with args %v", payload.HostCallVectorStoreInsert, args) + // // verify the type of args is VectorStoreInsertRequest + // // use json marshal and unmarshal to verify the type + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + // req := payload.VectorStoreInsertRequest{} + // err = req.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // log.Infof("VectorStoreInsert Request: %s", string(jsonBytes)) + // // insert the vector into the vector store + // v, ok := globalVectorStoreRegistries[task.ID()] + // if !ok { + // return nil, fmt.Errorf("vector store registry not found") + // } + + // err = v.Insert(req.VID, req.Vector, req.Data) + // if err != nil { + // return nil, fmt.Errorf("error inserting vector: %v", err) + // } + + // // return the response + // return payload.VectorStoreInsertResponse{ + // VID: req.VID, + // }, nil + + return nil, fmt.Errorf("not implemented") +} + +func VectorStoreSearch(inv *hostcalls.InvocationInfo, args []byte) ([]byte, error) { + // task := *(inv.Task) + // log.Debugf("Executing hostcall \"%s\" with args %v", payload.HostCallVectorStoreSearch, args) + // // verify the type of args is VectorStoreSearchRequest + // // use json marshal and unmarshal to verify the type + // jsonBytes, err := json.Marshal(args) + // if err != nil { + // return nil, fmt.Errorf("error marshalling args: %v", err) + // } + // req := payload.VectorStoreSearchRequest{} + // err = req.Unmarshal(jsonBytes) + // if err != nil { + // return nil, fmt.Errorf("error unmarshalling args: %v", err) + // } + + // log.Infof("VectorStoreSearch Request: %s", string(jsonBytes)) + // // search the vector in the vector store + // v, ok := globalVectorStoreRegistries[task.ID()] + // if !ok { + // return nil, fmt.Errorf("vector store registry not found") + // } + + // result, err := v.Search(req.VID, req.Vector, req.Limit) + // if err != nil { + // return nil, fmt.Errorf("error searching vector: %v", err) + // } + + // // return the response + // res := payload.VectorStoreSearchResponse{ + // VID: req.VID, + // Entries: make([]payload.VectorStoreSearchResponseEntry, len(result)), + // } + // for i, r := range result { + // res.Entries[i] = payload.VectorStoreSearchResponseEntry{ + // Vector: r.Vector, + // Data: r.Data, + // } + // } + // return res, nil + + return nil, fmt.Errorf("not implemented") +} diff --git a/worker/worker.go b/spearlet/spearlet.go similarity index 65% rename from worker/worker.go rename to spearlet/spearlet.go index c93cf40..0fcf2c7 100644 --- a/worker/worker.go +++ b/spearlet/spearlet.go @@ -1,4 +1,4 @@ -package worker +package spearlet import ( "context" @@ 
-10,20 +10,23 @@ import ( "strconv" "time" + flatbuffers "github.com/google/flatbuffers/go" log "github.com/sirupsen/logrus" "github.com/lfedgeai/spear/pkg/common" - hc "github.com/lfedgeai/spear/worker/hostcalls" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" - "github.com/lfedgeai/spear/worker/task" - _ "github.com/lfedgeai/spear/worker/tools" + "github.com/lfedgeai/spear/pkg/spear/proto/custom" + "github.com/lfedgeai/spear/pkg/spear/proto/transport" + hc "github.com/lfedgeai/spear/spearlet/hostcalls" + hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" + "github.com/lfedgeai/spear/spearlet/task" + _ "github.com/lfedgeai/spear/spearlet/tools" ) var ( logLevel = log.InfoLevel ) -type WorkerConfig struct { +type SpearletConfig struct { Addr string Port string @@ -37,8 +40,8 @@ type WorkerConfig struct { SpearAddr string } -type Worker struct { - cfg *WorkerConfig +type Spearlet struct { + cfg *SpearletConfig mux *http.ServeMux srv *http.Server @@ -58,54 +61,49 @@ type TaskMetaData struct { var ( tmpMetaData = map[int]TaskMetaData{ - 1: { - Id: 1, - Type: task.TaskTypeDocker, - Image: "dummy", - Name: "dummy", - }, - 2: { - Id: 2, - Type: task.TaskTypeDocker, - Image: "voice_chat", - Name: "voice_chat", - }, 3: { Id: 3, Type: task.TaskTypeDocker, - Image: "gen_image", + Image: "gen_image:latest", Name: "gen_image", }, 4: { Id: 4, Type: task.TaskTypeDocker, - Image: "pychat", + Image: "pychat:latest", Name: "pychat", }, 5: { Id: 5, Type: task.TaskTypeDocker, - Image: "pytools", + Image: "pytools:latest", Name: "pytools", }, 6: { Id: 6, Type: task.TaskTypeDocker, - Image: "pyconversation", + Image: "pyconversation:latest", Name: "pyconversation", }, 7: { Id: 7, Type: task.TaskTypeDocker, - Image: "pydummy", + Image: "pydummy:latest", Name: "pydummy", }, + 8: { + Id: 8, + Type: task.TaskTypeDocker, + Image: "pytest-functionality:latest", + Name: "pytest-functionality", + }, } ) -// NewServeWorkerConfig creates a new WorkerConfig -func NewServeWorkerConfig(addr, port string, spath []string, debug bool, spearAddr string) *WorkerConfig { - return &WorkerConfig{ +// NewServeSpearletConfig creates a new SpearletConfig +func NewServeSpearletConfig(addr, port string, spath []string, debug bool, + spearAddr string) *SpearletConfig { + return &SpearletConfig{ Addr: addr, Port: port, SearchPath: spath, @@ -115,8 +113,8 @@ func NewServeWorkerConfig(addr, port string, spath []string, debug bool, spearAd } } -func NewExecWorkerConfig(debug bool, spearAddr string) *WorkerConfig { - return &WorkerConfig{ +func NewExecSpearletConfig(debug bool, spearAddr string) *SpearletConfig { + return &SpearletConfig{ Addr: "", Port: "", Debug: debug, @@ -125,8 +123,8 @@ func NewExecWorkerConfig(debug bool, spearAddr string) *WorkerConfig { } } -func NewWorker(cfg *WorkerConfig) *Worker { - w := &Worker{ +func NewSpearlet(cfg *SpearletConfig) *Spearlet { + w := &Spearlet{ cfg: cfg, mux: http.NewServeMux(), hc: nil, @@ -138,20 +136,20 @@ func NewWorker(cfg *WorkerConfig) *Worker { return w } -func (w *Worker) Initialize() { +func (w *Spearlet) Initialize() { w.addRoutes() w.addHostCalls() w.initializeRuntimes() go w.hc.Run() } -func (w *Worker) addHostCalls() { +func (w *Spearlet) addHostCalls() { for _, hc := range hc.Hostcalls { w.hc.RegisterHostCall(hc) } } -func (w *Worker) initializeRuntimes() { +func (w *Spearlet) initializeRuntimes() { cfg := &task.TaskRuntimeConfig{ Debug: w.cfg.Debug, Cleanup: true, @@ -203,13 +201,15 @@ func funcType(req *http.Request) (task.TaskType, error) { // get 
the runtime from the headers runtime := headers.Get(HeaderFuncType) if runtime == "" { - return task.TaskTypeUnknown, fmt.Errorf("missing %s header", HeaderFuncType) + return task.TaskTypeUnknown, + fmt.Errorf("missing %s header", HeaderFuncType) } // convert runtime to int i, err := strconv.Atoi(runtime) if err != nil { - return task.TaskTypeUnknown, fmt.Errorf("error parsing %s header: %v", HeaderFuncType, err) + return task.TaskTypeUnknown, + fmt.Errorf("error parsing %s header: %v", HeaderFuncType, err) } switch i { @@ -222,11 +222,12 @@ func funcType(req *http.Request) (task.TaskType, error) { case int(task.TaskTypeWasm): return task.TaskTypeWasm, nil default: - return task.TaskTypeUnknown, fmt.Errorf("invalid %s header: %s", HeaderFuncType, runtime) + return task.TaskTypeUnknown, + fmt.Errorf("invalid %s header: %s", HeaderFuncType, runtime) } } -func (w *Worker) LookupTaskId(name string) (int64, error) { +func (w *Spearlet) LookupTaskId(name string) (int64, error) { for _, v := range tmpMetaData { if v.Name == name { return v.Id, nil @@ -235,7 +236,7 @@ func (w *Worker) LookupTaskId(name string) (int64, error) { return -1, fmt.Errorf("error: task name not found: %s", name) } -func (w *Worker) ListTasks() []string { +func (w *Spearlet) ListTasks() []string { var tasks []string for _, v := range tmpMetaData { tasks = append(tasks, v.Name) @@ -243,7 +244,18 @@ func (w *Worker) ListTasks() []string { return tasks } -func (w *Worker) ExecuteTask(taskId int64, funcType task.TaskType, wait bool, method string, data string) (string, error) { +func (w *Spearlet) ExecuteTaskByName(name string, wait bool, method string, + data string) (string, error) { + for _, v := range tmpMetaData { + if v.Name == name { + return w.ExecuteTask(v.Id, v.Type, wait, method, data) + } + } + return "", fmt.Errorf("error: task name not found: %s", name) +} + +func (w *Spearlet) ExecuteTask(taskId int64, funcType task.TaskType, wait bool, + method string, data string) (string, error) { rt, err := task.GetTaskRuntime(funcType) if err != nil { return "", fmt.Errorf("error: %v", err) @@ -258,13 +270,15 @@ func (w *Worker) ExecuteTask(taskId int64, funcType task.TaskType, wait bool, me return "", fmt.Errorf("error: invalid task type: %d", funcType) } + log.Infof("Using image: %s", meta.Image) + randSrc := rand.NewSource(time.Now().UnixNano()) randGen := rand.New(randSrc) newTask, err := rt.CreateTask(&task.TaskConfig{ - Name: fmt.Sprintf("task-%s-%d", meta.Name, randGen.Intn(10000)), - Cmd: "/start", //"sh", //"./dummy_task", - Args: []string{}, - Image: meta.Image, + Name: fmt.Sprintf("task-%s-%d", meta.Name, randGen.Intn(10000)), + Cmd: "/start", //"sh", //"./dummy_task", + Args: []string{}, + Image: meta.Image, HostAddr: w.spearAddr, }) if err != nil { @@ -279,17 +293,36 @@ func (w *Worker) ExecuteTask(taskId int64, funcType task.TaskType, wait bool, me newTask.Start() res := "" - if r, err := w.commMgr.SendOutgoingRPCRequest(newTask, method, data); err != nil { + builder := flatbuffers.NewBuilder(512) + methodOff := builder.CreateString(method) + dataOff := builder.CreateString(data) + custom.CustomRequestStart(builder) + custom.CustomRequestAddMethodStr(builder, methodOff) + custom.CustomRequestAddParamsStr(builder, dataOff) + builder.Finish(custom.CustomRequestEnd(builder)) + + if r, err := w.commMgr.SendOutgoingRPCRequest(newTask, transport.MethodCustom, + builder.FinishedBytes()); err != nil { return "", fmt.Errorf("error: %v", err) } else { + if len(r.ResponseBytes()) == 0 { + return "", nil // no response + 
} + customResp := custom.GetRootAsCustomResponse(r.ResponseBytes(), 0) // marshal the result - if resTmp, err := json.Marshal(r.Result); err != nil { + if resTmp, err := json.Marshal(customResp.DataBytes()); err != nil { return "", fmt.Errorf("error: %v", err) } else { res = string(resTmp) } } + // terminate the task by sending a signal + if err := w.commMgr.SendOutgoingRPCSignal(newTask, transport.SignalTerminate, + []byte{}); err != nil { + return "", fmt.Errorf("error: %v", err) + } + if wait { // wait for the task to finish newTask.Wait() @@ -298,7 +331,7 @@ func (w *Worker) ExecuteTask(taskId int64, funcType task.TaskType, wait bool, me return res, nil } -func (w *Worker) addRoutes() { +func (w *Spearlet) addRoutes() { w.mux.HandleFunc("/health", func(resp http.ResponseWriter, req *http.Request) { resp.Write([]byte("OK")) }) @@ -331,7 +364,8 @@ func (w *Worker) addRoutes() { respError(resp, fmt.Sprintf("Error: %v", err)) return } - res, err := w.ExecuteTask(taskId, funcType, !funcIsAsync, "handle", string(buf[:n])) + res, err := w.ExecuteTask(taskId, funcType, !funcIsAsync, "handle", + string(buf[:n])) if err != nil { respError(resp, fmt.Sprintf("Error: %v", err)) return @@ -340,8 +374,8 @@ func (w *Worker) addRoutes() { }) } -func (w *Worker) StartServer() { - log.Infof("Starting worker on %s:%s", w.cfg.Addr, w.cfg.Port) +func (w *Spearlet) StartServer() { + log.Infof("Starting spearlet on %s:%s", w.cfg.Addr, w.cfg.Port) srv := &http.Server{ Addr: w.cfg.Addr + ":" + w.cfg.Port, Handler: w.mux, @@ -356,8 +390,8 @@ func (w *Worker) StartServer() { } } -func (w *Worker) Stop() { - log.Debugf("Stopping worker") +func (w *Spearlet) Stop() { + log.Debugf("Stopping spearlet") if w.srv != nil { w.srv.Shutdown(context.Background()) } diff --git a/worker/task/const.go b/spearlet/task/const.go similarity index 100% rename from worker/task/const.go rename to spearlet/task/const.go diff --git a/worker/task/docker.go b/spearlet/task/docker.go similarity index 91% rename from worker/task/docker.go rename to spearlet/task/docker.go index 53b5477..77dc393 100644 --- a/worker/task/docker.go +++ b/spearlet/task/docker.go @@ -37,14 +37,15 @@ func (p *DockerTask) ID() TaskID { } func (p *DockerTask) Start() error { - err := p.runtime.cli.ContainerStart(context.TODO(), p.container.ID, container.StartOptions{}) + err := p.runtime.cli.ContainerStart(context.TODO(), p.container.ID, + container.StartOptions{}) if err != nil { return err } go func() { <-p.connReady - log.Debugf("Connection ready for task %s", p.name) + log.Infof("Connection ready for task %s", p.name) // input goroutine go func() { @@ -102,12 +103,13 @@ func (p *DockerTask) Start() error { }() // get stdin and stdout - val, err := p.runtime.cli.ContainerAttach(context.TODO(), p.container.ID, container.AttachOptions{ - Stream: true, - Stdin: true, - Stdout: false, - Stderr: false, - }) + val, err := p.runtime.cli.ContainerAttach(context.TODO(), p.container.ID, + container.AttachOptions{ + Stream: true, + Stdin: true, + Stdout: false, + Stderr: false, + }) if err != nil { return err } @@ -146,7 +148,8 @@ func (p *DockerTask) Start() error { data = data[8:] n = n - 8 // big endian size - sz := int(header[4])<<24 | int(header[5])<<16 | int(header[6])<<8 | int(header[7]) + sz := int(header[4])<<24 | int(header[5])<<16 | + int(header[6])<<8 | int(header[7]) log.Debugf("Size: %d, ReadLen: %d, Got data: %s", sz, n, data) if header[0] == 0x01 { // stdout @@ -176,7 +179,8 @@ func (p *DockerTask) Start() error { } func (p *DockerTask) Stop() error { - 
err := p.runtime.cli.ContainerStop(context.TODO(), p.container.ID, container.StopOptions{}) + err := p.runtime.cli.ContainerStop(context.TODO(), p.container.ID, + container.StopOptions{}) if err != nil { return err } @@ -201,7 +205,8 @@ func (p *DockerTask) CommChannels() (chan Message, chan Message, error) { } func (p *DockerTask) Wait() (int, error) { - c, err := p.runtime.cli.ContainerWait(context.TODO(), p.container.ID, container.WaitConditionNotRunning) + c, err := p.runtime.cli.ContainerWait(context.TODO(), p.container.ID, + container.WaitConditionNotRunning) select { case <-c: return 0, nil @@ -212,8 +217,9 @@ func (p *DockerTask) Wait() (int, error) { } func (p *DockerTask) NextRequestID() uint64 { - p.reqId++ - return p.reqId + res := p.reqId + p.reqId += 1 + return res } func (p *DockerTask) SetVar(key TaskVar, value interface{}) { diff --git a/worker/task/docker/utils.go b/spearlet/task/docker/utils.go similarity index 100% rename from worker/task/docker/utils.go rename to spearlet/task/docker/utils.go diff --git a/worker/task/docker_rt.go b/spearlet/task/docker_rt.go similarity index 96% rename from worker/task/docker_rt.go rename to spearlet/task/docker_rt.go index 09d1aa6..8828158 100644 --- a/worker/task/docker_rt.go +++ b/spearlet/task/docker_rt.go @@ -11,7 +11,7 @@ import ( "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" - "github.com/lfedgeai/spear/worker/task/docker" + "github.com/lfedgeai/spear/spearlet/task/docker" log "github.com/sirupsen/logrus" ) @@ -122,7 +122,8 @@ func (d *DockerTaskRuntime) CreateTask(cfg *TaskConfig) (Task, error) { AttachStderr: true, OpenStdin: true, Env: []string{ - fmt.Sprintf("SERVICE_ADDR=%s:%s", cfg.HostAddr, DockerRuntimeTcpListenPort), + fmt.Sprintf("SERVICE_ADDR=%s:%s", cfg.HostAddr, + DockerRuntimeTcpListenPort), fmt.Sprintf("SECRET=%d", secretGenerated), }, } @@ -149,7 +150,7 @@ func (d *DockerTaskRuntime) CreateTask(cfg *TaskConfig) (Task, error) { conn: nil, connReady: make(chan struct{}), - reqId: 0, + reqId: 1, taskVars: make(map[TaskVar]interface{}), taskVarsMu: sync.RWMutex{}, diff --git a/worker/task/proc.go b/spearlet/task/proc.go similarity index 94% rename from worker/task/proc.go rename to spearlet/task/proc.go index 5a50237..129ba3e 100644 --- a/worker/task/proc.go +++ b/spearlet/task/proc.go @@ -3,6 +3,7 @@ package task import ( "fmt" "os/exec" + "strconv" "sync" log "github.com/sirupsen/logrus" @@ -29,7 +30,7 @@ type ProcessTask struct { } func (p *ProcessTask) ID() TaskID { - return TaskID(p.cmd.Process.Pid) + return TaskID(strconv.Itoa(p.cmd.Process.Pid)) } func (p *ProcessTask) Start() error { @@ -87,8 +88,9 @@ func (p *ProcessTask) Wait() (int, error) { } func (p *ProcessTask) NextRequestID() uint64 { - p.reqId++ - return p.reqId + res := p.reqId + p.reqId += 1 + return res } func (p *ProcessTask) SetVar(key TaskVar, value interface{}) { @@ -118,7 +120,7 @@ func NewProcessTask(cfg *TaskConfig) *ProcessTask { status: TaskStatusInit, result: nil, done: make(chan struct{}), - reqId: 0, + reqId: 1, taskVars: make(map[TaskVar]interface{}), taskVarsMu: sync.RWMutex{}, } diff --git a/worker/task/proc_rt.go b/spearlet/task/proc_rt.go similarity index 100% rename from worker/task/proc_rt.go rename to spearlet/task/proc_rt.go diff --git a/worker/task/task.go b/spearlet/task/task.go similarity index 98% rename from worker/task/task.go rename to spearlet/task/task.go index 4119422..364a29e 100644 --- a/worker/task/task.go +++ b/spearlet/task/task.go @@ -8,10 +8,10 @@ import ( type TaskConfig 
struct { // task name - Name string - Image string - Cmd string - Args []string + Name string + Image string + Cmd string + Args []string HostAddr string } diff --git a/worker/tools/contact.go b/spearlet/tools/contact.go similarity index 90% rename from worker/tools/contact.go rename to spearlet/tools/contact.go index 0b5657b..942e645 100644 --- a/worker/tools/contact.go +++ b/spearlet/tools/contact.go @@ -3,13 +3,15 @@ package tools import ( "os/exec" - hccommon "github.com/lfedgeai/spear/worker/hostcalls/common" + hccommon "github.com/lfedgeai/spear/spearlet/hostcalls/common" log "github.com/sirupsen/logrus" ) var contactTools = []hccommon.ToolRegistry{ { + ToolType: hccommon.ToolType_Builtin, Name: "search_contact_email", + Id: hccommon.BuiltinToolID_SearchContactEmail, Description: "Search for a person's email address in Contacts", Params: map[string]hccommon.ToolParam{ "name": { @@ -18,7 +20,6 @@ var contactTools = []hccommon.ToolRegistry{ Required: true, }, }, - Cb: "", CbBuiltIn: func(inv *hccommon.InvocationInfo, args interface{}) (interface{}, error) { // use apple script to search for contact log.Infof("Searching for contact with name %s", args.(map[string]interface{})["name"].(string)) diff --git a/worker/tools/datetime.go b/spearlet/tools/datetime.go similarity index 81% rename from worker/tools/datetime.go rename to spearlet/tools/datetime.go index 5249a74..9f174f7 100644 --- a/worker/tools/datetime.go +++ b/spearlet/tools/datetime.go @@ -4,19 +4,22 @@ import ( "fmt" "time" - hccommon "github.com/lfedgeai/spear/worker/hostcalls/common" + hccommon "github.com/lfedgeai/spear/spearlet/hostcalls/common" ) var dtTools = []hccommon.ToolRegistry{ { + ToolType: hccommon.ToolType_Builtin, Name: "datetime", + Id: hccommon.BuiltinToolID_Datetime, Description: "Get current date and time, including timezone information", Params: map[string]hccommon.ToolParam{}, - Cb: "", CbBuiltIn: datetime, }, { + ToolType: hccommon.ToolType_Builtin, Name: "sleep", + Id: hccommon.BuiltinToolID_Sleep, Description: "Sleep for a specified number of seconds", Params: map[string]hccommon.ToolParam{ "seconds": { @@ -25,7 +28,6 @@ var dtTools = []hccommon.ToolRegistry{ Required: true, }, }, - Cb: "", CbBuiltIn: sleep, }, } diff --git a/worker/tools/email.go b/spearlet/tools/email.go similarity index 92% rename from worker/tools/email.go rename to spearlet/tools/email.go index 5b9c0af..4766ddf 100644 --- a/worker/tools/email.go +++ b/spearlet/tools/email.go @@ -3,20 +3,23 @@ package tools import ( "os/exec" - hccommon "github.com/lfedgeai/spear/worker/hostcalls/common" + hccommon "github.com/lfedgeai/spear/spearlet/hostcalls/common" log "github.com/sirupsen/logrus" ) var emailTools = []hccommon.ToolRegistry{ { + ToolType: hccommon.ToolType_Builtin, Name: "list_open_emails", + Id: hccommon.BuiltinToolID_ListOpenEmails, Description: "List all open email drafts window", Params: map[string]hccommon.ToolParam{}, - Cb: "", CbBuiltIn: listOpenEmails, }, { - Name: "compose_email", + ToolType: hccommon.ToolType_Builtin, + Name: "compose_email", + Id: hccommon.BuiltinToolID_ComposeEmail, Description: `Compose an email, open a draft window with the email pre-filled. 
NOTE: the email has to be a valid email address, you need to get it from other tools or from the user directly`, Params: map[string]hccommon.ToolParam{ @@ -36,11 +39,12 @@ var emailTools = []hccommon.ToolRegistry{ Required: true, }, }, - Cb: "", CbBuiltIn: composeEmail, }, { - Name: "send_email_draft_window", + ToolType: hccommon.ToolType_Builtin, + Name: "send_email_draft_window", + Id: hccommon.BuiltinToolID_SendEmailDraftWindow, Description: `Activate the email draft window and send the email. NOTE: 1. Call the tool "list_open_emails" to list available email windows before calling this function. 2. Before call this tool to actually send the email, assitant needs to stop & ask the user to say yes`, @@ -51,7 +55,6 @@ var emailTools = []hccommon.ToolRegistry{ Required: true, }, }, - Cb: "", CbBuiltIn: sendEmailDraftWindow, }, } diff --git a/worker/tools/mouse.go b/spearlet/tools/mouse.go similarity index 79% rename from worker/tools/mouse.go rename to spearlet/tools/mouse.go index 0a61304..b970af2 100644 --- a/worker/tools/mouse.go +++ b/spearlet/tools/mouse.go @@ -4,15 +4,16 @@ import ( "time" "github.com/go-vgo/robotgo" - hccommon "github.com/lfedgeai/spear/worker/hostcalls/common" + hccommon "github.com/lfedgeai/spear/spearlet/hostcalls/common" ) var mouseTools = []hccommon.ToolRegistry{ { + ToolType: hccommon.ToolType_Builtin, Name: "mouse_right_click", + Id: hccommon.BuiltinToolID_MouseRightClick, Description: `Right click the mouse at the current location.`, Params: map[string]hccommon.ToolParam{}, - Cb: "", CbBuiltIn: func(inv *hccommon.InvocationInfo, args interface{}) (interface{}, error) { robotgo.Toggle("right") time.Sleep(300 * time.Millisecond) @@ -21,10 +22,11 @@ var mouseTools = []hccommon.ToolRegistry{ }, }, { + ToolType: hccommon.ToolType_Builtin, Name: "mouse_left_click", + Id: hccommon.BuiltinToolID_MouseLeftClick, Description: `Left click the mouse at the current location.`, Params: map[string]hccommon.ToolParam{}, - Cb: "", CbBuiltIn: func(inv *hccommon.InvocationInfo, args interface{}) (interface{}, error) { robotgo.Toggle("left") time.Sleep(300 * time.Millisecond) diff --git a/worker/tools/phone.go b/spearlet/tools/phone.go similarity index 90% rename from worker/tools/phone.go rename to spearlet/tools/phone.go index 251d412..1fc93d0 100644 --- a/worker/tools/phone.go +++ b/spearlet/tools/phone.go @@ -4,7 +4,7 @@ import ( "fmt" "os" - hccommon "github.com/lfedgeai/spear/worker/hostcalls/common" + hccommon "github.com/lfedgeai/spear/spearlet/hostcalls/common" "github.com/twilio/twilio-go" twilioApi "github.com/twilio/twilio-go/rest/api/v2010" @@ -18,7 +18,9 @@ var ( var phoneTools = []hccommon.ToolRegistry{ { + ToolType: hccommon.ToolType_Builtin, Name: "phone_call", + Id: hccommon.BuiltinToolID_PhoneCall, Description: "Call a phone number and play a message", Params: map[string]hccommon.ToolParam{ "phone_number": { @@ -32,7 +34,6 @@ var phoneTools = []hccommon.ToolRegistry{ Required: true, }, }, - Cb: "", CbBuiltIn: func(inv *hccommon.InvocationInfo, args interface{}) (interface{}, error) { if twilioAccountSid == "" || twilioApiSecret == "" { return nil, fmt.Errorf("twilio credentials not set") diff --git a/worker/tools/screen.go b/spearlet/tools/screen.go similarity index 87% rename from worker/tools/screen.go rename to spearlet/tools/screen.go index 92d5dde..951aaa8 100644 --- a/worker/tools/screen.go +++ b/spearlet/tools/screen.go @@ -7,12 +7,14 @@ import ( "github.com/kbinani/screenshot" - hccommon "github.com/lfedgeai/spear/worker/hostcalls/common" + hccommon 
"github.com/lfedgeai/spear/spearlet/hostcalls/common" ) var screenTools = []hccommon.ToolRegistry{ { + ToolType: hccommon.ToolType_Builtin, Name: "full_screenshot", + Id: hccommon.BuiltinToolID_FullScreenshot, Description: `Take screenshots of everything on all screens, and save them to files`, Params: map[string]hccommon.ToolParam{ "filename-prefix": { @@ -21,7 +23,6 @@ var screenTools = []hccommon.ToolRegistry{ Required: true, }, }, - Cb: "", CbBuiltIn: screenshotCall, }, } diff --git a/worker/tools/web.go b/spearlet/tools/web.go similarity index 87% rename from worker/tools/web.go rename to spearlet/tools/web.go index 27a0f8b..7f1f5eb 100644 --- a/worker/tools/web.go +++ b/spearlet/tools/web.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - hccommon "github.com/lfedgeai/spear/worker/hostcalls/common" + hccommon "github.com/lfedgeai/spear/spearlet/hostcalls/common" log "github.com/sirupsen/logrus" "github.com/chromedp/chromedp" @@ -14,7 +14,9 @@ import ( var webTools = []hccommon.ToolRegistry{ { + ToolType: hccommon.ToolType_Builtin, Name: "open_url", + Id: hccommon.BuiltinToolID_OpenURL, Description: `Open a URL in the default browser`, Params: map[string]hccommon.ToolParam{ "url": { @@ -23,11 +25,12 @@ var webTools = []hccommon.ToolRegistry{ Required: true, }, }, - Cb: "", CbBuiltIn: openUrl, }, { + ToolType: hccommon.ToolType_Builtin, Name: "scroll_down", + Id: hccommon.BuiltinToolID_ScrollDown, Description: `Scroll down the page using arrowdown key`, Params: map[string]hccommon.ToolParam{ "times": { @@ -36,11 +39,12 @@ var webTools = []hccommon.ToolRegistry{ Required: true, }, }, - Cb: "", CbBuiltIn: scrollDown, }, { + ToolType: hccommon.ToolType_Builtin, Name: "scroll_up", + Id: hccommon.BuiltinToolID_ScrollUp, Description: `Scroll up the page using arrowup key`, Params: map[string]hccommon.ToolParam{ "times": { @@ -49,28 +53,30 @@ var webTools = []hccommon.ToolRegistry{ Required: true, }, }, - Cb: "", CbBuiltIn: scrollUp, }, { + ToolType: hccommon.ToolType_Builtin, Name: "page_up", + Id: hccommon.BuiltinToolID_PageUp, Description: `Scroll up the page using pageup key`, Params: map[string]hccommon.ToolParam{}, - Cb: "", CbBuiltIn: pageUp, }, { + ToolType: hccommon.ToolType_Builtin, Name: "page_down", + Id: hccommon.BuiltinToolID_PageDown, Description: `Scroll down the page using pagedown key`, Params: map[string]hccommon.ToolParam{}, - Cb: "", CbBuiltIn: pageDown, }, { + ToolType: hccommon.ToolType_Builtin, Name: "web_screenshot", + Id: hccommon.BuiltinToolID_WebScreenshot, Description: `Take a screenshot of the current web page. 
This won't take a screenshot of the entire screen`, Params: map[string]hccommon.ToolParam{}, - Cb: "", CbBuiltIn: webScreenshot, }, } diff --git a/worker/types.go b/spearlet/types.go similarity index 87% rename from worker/types.go rename to spearlet/types.go index 72d9654..47cadec 100644 --- a/worker/types.go +++ b/spearlet/types.go @@ -1,4 +1,4 @@ -package worker +package spearlet const ( HeaderFuncId = "Spear-Func-Id" diff --git a/test/functionality_test.go b/test/functionality_test.go new file mode 100644 index 0000000..19ebc78 --- /dev/null +++ b/test/functionality_test.go @@ -0,0 +1,25 @@ +package test + +import ( + "testing" + + "github.com/lfedgeai/spear/pkg/common" + "github.com/lfedgeai/spear/spearlet" +) + +func TestFunctionality(t *testing.T) { + // create config + config := spearlet.NewExecSpearletConfig(true, common.SpearPlatformAddress) + w := spearlet.NewSpearlet(config) + w.Initialize() + + res, err := w.ExecuteTaskByName("pytest-functionality", true, "handle", "") + if err != nil { + t.Fatalf("Error executing workload: %v", err) + } + if len(res) > 1024 { + res = res[:1024] + "..." + } + t.Logf("Workload execution result: %v", res) + w.Stop() +} diff --git a/test/simple_req_test.go b/test/simple_req_test.go index 942d1cc..7ecd15a 100644 --- a/test/simple_req_test.go +++ b/test/simple_req_test.go @@ -1,100 +1,19 @@ package test import ( - "fmt" - "net/http" "testing" - "bytes" - "github.com/lfedgeai/spear/pkg/common" - "github.com/lfedgeai/spear/pkg/tools/docker" - "github.com/lfedgeai/spear/worker" - "github.com/lfedgeai/spear/worker/task" + "github.com/lfedgeai/spear/spearlet" ) -func TestSimpleReq(t *testing.T) { - // TestSimpleReq tests simple requests to the worker - // ┌──────────────────┐ - // │ │ - // │ │ - // │ Docker │ - // │ Vector Store │ - // │ │ - // └───────────┬──────┘ - // ▲ │ - // │ ▼ - // ┌──────┴───────────┐ - // │ │ - // │ │ - // │ Worker │ - // │ │ - // │ │ - // └────────────┬─────┘ - // ▲ │ - // │ ▼ - // ┌─────┴────────────┐ - // │ │ - // │ │ - // │ Task │ - // │ │ - // │ │ - // └──────────────────┘ - - // setup the test environment - s := docker.NewTestSetup() - defer s.TearDown() - // send a http request to the server and check the response - - // create a http client - client := &http.Client{} - - // create a http request - req, err := http.NewRequest("GET", "http://localhost:8080", bytes.NewBuffer( - []byte( - `this is a - multiline test`, - ), - )) - if err != nil { - t.Fatalf("Error: %v", err) - } - - // add headers - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - req.Header.Add("Spear-Func-Id", "1") - req.Header.Add("Spear-Func-Type", "1") - - // send the request - resp, err := client.Do(req) - if err != nil { - t.Fatalf("Error: %v", err) - } - - // check the response - if resp.StatusCode != http.StatusOK { - msg := make([]byte, 1024) - n, _ := resp.Body.Read(msg) - t.Fatalf("Error: %v %s", resp.Status, msg[:n]) - } - - // print body - msg := make([]byte, 1024) - n, _ := resp.Body.Read(msg) - fmt.Printf("Response: %s\n", msg[:n]) - - // close the response body - resp.Body.Close() -} - -func TestLocalDummy(t *testing.T) { +func TestLocalPydummy(t *testing.T) { // create config - config := worker.NewExecWorkerConfig(true, common.SpearPlatformAddress) - w := worker.NewWorker(config) + config := spearlet.NewExecSpearletConfig(true, common.SpearPlatformAddress) + w := spearlet.NewSpearlet(config) w.Initialize() - res, err := w.ExecuteTask(1, task.TaskTypeDocker, true, "handle", "hi") + res, err 
:= w.ExecuteTaskByName("pydummy", true, "handle", "") if err != nil { t.Fatalf("Error executing workload: %v", err) } @@ -102,33 +21,19 @@ func TestLocalDummy(t *testing.T) { w.Stop() } -func TestLocalPydummy(t *testing.T) { +func TestLocalGenImage(t *testing.T) { // create config - config := worker.NewExecWorkerConfig(true, common.SpearPlatformAddress) - w := worker.NewWorker(config) + config := spearlet.NewExecSpearletConfig(true, common.SpearPlatformAddress) + w := spearlet.NewSpearlet(config) w.Initialize() - res, err := w.ExecuteTask(7, task.TaskTypeDocker, true, "handle", "") + res, err := w.ExecuteTaskByName("gen_image", true, "handle", "a red bird.") if err != nil { t.Fatalf("Error executing workload: %v", err) } + if len(res) > 1024 { + res = res[:1024] + "..." + } t.Logf("Workload execution result: %v", res) w.Stop() } - -// func TestLocalGenImage(t *testing.T) { -// // create config -// config := worker.NewExecWorkerConfig(true, common.SpearPlatformAddress) -// w := worker.NewWorker(config) -// w.Initialize() - -// res, err := w.ExecuteTask(3, task.TaskTypeDocker, true, "handle", "a red bird.","") -// if err != nil { -// t.Fatalf("Error executing workload: %v", err) -// } -// if len(res) > 1024 { -// res = res[:1024] + "..." -// } -// t.Logf("Workload execution result: %v", res) -// w.Stop() -// } diff --git a/test/test_guide.md b/test/test_guide.md index 1f80839..5c767e6 100644 --- a/test/test_guide.md +++ b/test/test_guide.md @@ -48,7 +48,7 @@ Example Result: ```bash === RUN TestSimpleReq time="2024-12-18T15:49:56+08:00" level=info msg="Starting docker hostcall TCP server on port 8502" -time="2024-12-18T15:50:06+08:00" level=info msg="Starting worker on localhost:8080" +time="2024-12-18T15:50:06+08:00" level=info msg="Starting spearlet on localhost:8080" time="2024-12-18T15:50:11+08:00" level=info msg="Using transform registry chat_with_tools" time="2024-12-18T15:50:11+08:00" level=info msg="Using model gpt-4o" time="2024-12-18T15:50:11+08:00" level=info msg="Found 1 endpoints for gpt-4o: [{openai-toolchat gpt-4o https://api.chatanywhere.tech/v1 ******** /chat/completions}]" diff --git a/worker/hostcalls/chat.go b/worker/hostcalls/chat.go deleted file mode 100644 index a87de3f..0000000 --- a/worker/hostcalls/chat.go +++ /dev/null @@ -1,385 +0,0 @@ -package hostcalls - -import ( - "encoding/json" - "fmt" - - "github.com/lfedgeai/spear/pkg/rpc" - "github.com/lfedgeai/spear/pkg/rpc/payload" - "github.com/lfedgeai/spear/worker/hostcalls/common" - hcommon "github.com/lfedgeai/spear/worker/hostcalls/common" - hcopenai "github.com/lfedgeai/spear/worker/hostcalls/openai" - "github.com/lfedgeai/spear/worker/task" - log "github.com/sirupsen/logrus" -) - -type ChatMessage struct { - Index int `json:"index"` - Metadata map[string]interface{} `json:"metadata"` - Content string `json:"content"` -} - -type ChatCompletionMemory struct { - Messages []ChatMessage `json:"messages"` -} - -func NewChatCompletionMemory() *ChatCompletionMemory { - return &ChatCompletionMemory{ - Messages: make([]ChatMessage, 0), - } -} - -func (m *ChatCompletionMemory) AddMessage(msg ChatMessage) { - m.Messages = append(m.Messages, msg) -} - -func (m *ChatCompletionMemory) Clear() { - m.Messages = make([]ChatMessage, 0) -} - -func (m *ChatCompletionMemory) GetMessages() []ChatMessage { - return m.Messages -} - -func ChatCompletionNoTools(inv *hcommon.InvocationInfo, args interface{}) (interface{}, error) { - // verify the type of args is ChatCompletionRequest - // use json marshal and unmarshal to verify the type - 
jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - chatReq := payload.ChatCompletionRequest{} - err = chatReq.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - if chatReq.ToolsetId != "" { - log.Infof("Tools are not supported in this function") - return nil, fmt.Errorf("tools are not supported in this function") - } - - log.Infof("Using model %s", chatReq.Model) - - msgList, err := innerChatCompletionNoTools(inv, &chatReq) - if err != nil { - return nil, fmt.Errorf("error calling innerChatCompletionNoTools: %v", err) - } - - var res2 payload.ChatCompletionResponseV2 - res2.Messages = make([]payload.ChatMessageV2, len(msgList)) - for i, msg := range msgList { - md := map[string]interface{}{ - "role": msg.Metadata["role"], - } - if msg.Metadata["reason"] != nil { - md["reason"] = msg.Metadata["reason"] - } - res2.Messages[i] = payload.ChatMessageV2{ - Metadata: md, - Content: msg.Content, - } - } - return res2, nil -} - -func ChatCompletionWithTools(inv *hcommon.InvocationInfo, args interface{}) (interface{}, error) { - // verify the type of args is ChatCompletionRequest - // use json marshal and unmarshal to verify the type - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - chatReq := payload.ChatCompletionRequest{} - err = chatReq.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - log.Infof("Using model %s", chatReq.Model) - - msgList, err := innerChatCompletionWithTools(inv, &chatReq) - if err != nil { - return nil, fmt.Errorf("error calling innerChatCompletionWithTools: %v", err) - } - - var res2 payload.ChatCompletionResponseV2 - res2.Messages = make([]payload.ChatMessageV2, len(msgList)) - for i, msg := range msgList { - md := map[string]interface{}{ - "role": msg.Metadata["role"], - } - if msg.Metadata["reason"] != nil { - md["reason"] = msg.Metadata["reason"] - } - if msg.Metadata["tool_call_id"] != nil { - md["tool_call_id"] = msg.Metadata["tool_call_id"] - } - if msg.Metadata["tool_calls"] != nil { - md["tool_calls"] = msg.Metadata["tool_calls"] - } - res2.Messages[i] = payload.ChatMessageV2{ - Metadata: md, - Content: msg.Content, - } - } - return res2, nil -} - -func setupOpenAITools(chatReq *hcopenai.OpenAIChatCompletionRequest, task task.Task, toolsetId hcommon.ToolsetId) error { - toolset, ok := GetToolset(task, toolsetId) - if !ok { - return fmt.Errorf("toolset not found") - } - tools := make([]*hcommon.ToolRegistry, 0) - for _, toolId := range toolset.ToolsIds { - tool, ok := GetToolById(task, toolId) - if ok { - tools = append(tools, tool) - } - } - if len(tools) == 0 { - return fmt.Errorf("no tools found in toolset") - } - chatReq.Tools = make([]hcopenai.OpenAIChatToolFunction, len(tools)) - for i, tool := range tools { - requiredParams := make([]string, 0) - chatReq.Tools[i] = hcopenai.OpenAIChatToolFunction{ - Type: "function", - Func: hcopenai.OpenAIChatToolFunctionSub{ - Name: tool.Name, - Description: tool.Description, - Parameters: hcopenai.OpenAIChatToolParameter{ - Type: "object", - AdditionalProperties: false, - Properties: make(map[string]hcopenai.OpenAIChatToolParameterProperty), - }, - }, - } - for k, v := range tool.Params { - chatReq.Tools[i].Func.Parameters.Properties[k] = hcopenai.OpenAIChatToolParameterProperty{ - Type: v.Ptype, - Description: v.Description, - } - if v.Required { - requiredParams = 
append(requiredParams, k) - } - } - chatReq.Tools[i].Func.Parameters.Required = requiredParams - // log.Infof("Tool: %v", chatReq.Tools[i]) - } - return nil -} - -func innerChatCompletionWithTools(inv *hcommon.InvocationInfo, chatReq *payload.ChatCompletionRequest) ([]ChatMessage, error) { - task := *(inv.Task) - - mem := NewChatCompletionMemory() - for _, msg := range chatReq.Messages { - tmp := ChatMessage{ - Metadata: msg.Metadata, - Content: msg.Content, - } - mem.AddMessage(tmp) - } - - finished := false - var respData *hcopenai.OpenAIChatCompletionResponse - var err error - for !finished { - // create a new chat request - openAiChatReq2 := hcopenai.OpenAIChatCompletionRequest{ - Model: chatReq.Model, - Messages: []hcopenai.OpenAIChatMessage{}, - } - for _, msg := range mem.GetMessages() { - tmp := hcopenai.OpenAIChatMessage{ - Content: msg.Content, - } - if msg.Metadata["role"] != nil { - tmp.Role = msg.Metadata["role"].(string) - } - if msg.Metadata["tool_calls"] != nil { - log.Debugf("Tool calls: %v", msg.Metadata["tool_calls"]) - switch msg.Metadata["tool_calls"].(type) { - case []hcopenai.OpenAIChatToolCall: - tmp.ToolCalls = msg.Metadata["tool_calls"].([]hcopenai.OpenAIChatToolCall) - case []interface{}: - // marshal the interface{} to json and unmarshal to OpenAIChatToolCall - toolCalls := msg.Metadata["tool_calls"].([]interface{}) - toolCallsStr, err := json.Marshal(toolCalls) - if err != nil { - return nil, fmt.Errorf("error marshalling tool calls: %v", err) - } - var toolCalls2 []hcopenai.OpenAIChatToolCall - err = json.Unmarshal(toolCallsStr, &toolCalls2) - if err != nil { - return nil, fmt.Errorf("error unmarshalling tool calls: %v", err) - } - tmp.ToolCalls = toolCalls2 - default: - return nil, fmt.Errorf("unexpected type for tool_calls") - } - } - if msg.Metadata["tool_call_id"] != nil { - tmp.ToolCallId = msg.Metadata["tool_call_id"].(string) - } - openAiChatReq2.Messages = append(openAiChatReq2.Messages, tmp) - } - - // check if toolset exists - if chatReq.ToolsetId != "" { - err = setupOpenAITools(&openAiChatReq2, task, hcommon.ToolsetId(chatReq.ToolsetId)) - if err != nil { - return nil, fmt.Errorf("error setting up tools: %v", err) - } - } - - ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeChatWithTools, openAiChatReq2.Model) - if len(ep) == 0 { - return nil, fmt.Errorf("no endpoint found") - } - respData, err = hcopenai.OpenAIChatCompletion(ep[0], &openAiChatReq2) - if err != nil { - return nil, fmt.Errorf("error calling OpenAIChatCompletion: %v", err) - } - - log.Debugf("Response: %v", respData) - - for i, choice := range respData.Choices { - if choice.Index != json.Number(fmt.Sprintf("%d", i)) { - return nil, fmt.Errorf("index mismatch") - } - log.Infof("Reason: %s", choice.Reason) - if (choice.Reason == "stop" && len(choice.Message.ToolCalls) == 0) || choice.Reason == "length" { - mem.AddMessage(ChatMessage{ - Metadata: map[string]interface{}{ - "role": choice.Message.Role, - "reason": choice.Reason, - }, - Content: choice.Message.Content, - }) - finished = true - } else if choice.Reason == "tool_calls" || len(choice.Message.ToolCalls) > 0 { - mem.AddMessage(ChatMessage{ - Metadata: map[string]interface{}{ - "role": choice.Message.Role, - "tool_calls": choice.Message.ToolCalls, - "reason": choice.Reason, - }, - Content: choice.Message.Content, - }) - toolCalls := choice.Message.ToolCalls - for _, toolCall := range toolCalls { - argsStr := toolCall.Function.Arguments - // use json to unmarshal the arguments to interface{} - var args interface{} = nil - 
if argsStr != "" { - err := json.Unmarshal([]byte(argsStr), &args) - if err != nil { - return nil, fmt.Errorf("error unmarshalling tool call arguments: %v", err) - } - } - if toolReg, ok := GetToolByName(task, toolCall.Function.Name); ok && toolReg.Cb == "" { - // it is a built-in tool - fn := toolReg.CbBuiltIn - if fn == nil { - return nil, fmt.Errorf("built-in tool not implemented") - } - res, err := fn(inv, args) - if err != nil { - return nil, fmt.Errorf("error calling built-in tool %s: %v", toolReg.Name, err) - } - - tmp := fmt.Sprintf("%v", res) - if len(tmp) > 512 { - tmp = tmp[:509] + "..." - } - log.Infof("Builtin Tool call response: %v", tmp) - mem.AddMessage(ChatMessage{ - Metadata: map[string]interface{}{ - "role": "tool", - "tool_call_id": toolCall.Id, - }, - Content: fmt.Sprintf("%v", res), - }) - } else { - err = inv.CommMgr.SendOutgoingRPCRequestCallback(task, toolCall.Function.Name, args, func(resp *rpc.JsonRPCResponse) error { - log.Infof("External Tool call response: %v", resp) - return nil - }) - if err != nil { - return nil, fmt.Errorf("error sending tool call: %v", err) - } - } - } - } else { - return nil, fmt.Errorf("unexpected reason: %s", choice.Reason) - } - } - } - - return mem.GetMessages(), nil -} - -func innerChatCompletionNoTools(inv *hcommon.InvocationInfo, chatReq *payload.ChatCompletionRequest) ([]ChatMessage, error) { - mem := NewChatCompletionMemory() - for _, msg := range chatReq.Messages { - tmp := ChatMessage{ - Metadata: msg.Metadata, - Content: msg.Content, - } - mem.AddMessage(tmp) - } - - finished := false - var respData *hcopenai.OpenAIChatCompletionResponse - var err error - for !finished { - // create a new chat request - openAiChatReq2 := hcopenai.OpenAIChatCompletionRequest{ - Model: chatReq.Model, - Messages: []hcopenai.OpenAIChatMessage{}, - } - for _, msg := range mem.GetMessages() { - tmp := hcopenai.OpenAIChatMessage{ - Content: msg.Content, - } - if msg.Metadata["role"] != nil { - tmp.Role = msg.Metadata["role"].(string) - } - openAiChatReq2.Messages = append(openAiChatReq2.Messages, tmp) - } - - ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeChatWithTools, openAiChatReq2.Model) - if len(ep) == 0 { - return nil, fmt.Errorf("no endpoint found") - } - respData, err = hcopenai.OpenAIChatCompletion(ep[0], &openAiChatReq2) - if err != nil { - return nil, fmt.Errorf("error calling OpenAIChatCompletion: %v", err) - } - - log.Debugf("Response: %v", respData) - - for i, choice := range respData.Choices { - if choice.Index != json.Number(fmt.Sprintf("%d", i)) { - return nil, fmt.Errorf("index mismatch") - } - if choice.Reason == "stop" || choice.Reason == "length" { - mem.AddMessage(ChatMessage{ - Metadata: map[string]interface{}{ - "role": choice.Message.Role, - "reason": choice.Reason, - }, - Content: choice.Message.Content, - }) - finished = true - } else { - return nil, fmt.Errorf("unexpected reason: %s", choice.Reason) - } - } - } - - return mem.GetMessages(), nil -} diff --git a/worker/hostcalls/common/common.go b/worker/hostcalls/common/common.go deleted file mode 100644 index 7455436..0000000 --- a/worker/hostcalls/common/common.go +++ /dev/null @@ -1,271 +0,0 @@ -package common - -import ( - "encoding/json" - "fmt" - "sync" - "time" - - "github.com/lfedgeai/spear/pkg/rpc" - "github.com/lfedgeai/spear/worker/task" - log "github.com/sirupsen/logrus" -) - -type HostCall struct { - Name string - Handler HostCallHandler -} - -// invokation info -type InvocationInfo struct { - Task *task.Task - CommMgr *CommunicationManager -} - 
-type RespChanData struct { - Resp *rpc.JsonRPCResponse - InvInfo *InvocationInfo -} - -type ReqChanData struct { - Req *rpc.JsonRPCRequest - InvInfo *InvocationInfo -} - -// communication manager for hostcalls and guest responses -type CommunicationManager struct { - respCh chan *RespChanData // incoming responses - reqCh chan *ReqChanData // incoming requests - outCh map[task.Task]chan task.Message - - pendingRequests map[json.Number]*requestCallback - pendingRequestsMu sync.RWMutex -} - -type HostCallHandler func(inv *InvocationInfo, args interface{}) (interface{}, error) - -type HostCalls struct { - // map of hostcalls - HCMap map[string]HostCallHandler - CommMgr *CommunicationManager -} - -func NewHostCalls(commMgr *CommunicationManager) *HostCalls { - return &HostCalls{ - HCMap: make(map[string]HostCallHandler), - CommMgr: commMgr, - } -} - -func (h *HostCalls) RegisterHostCall(hc *HostCall) error { - name := hc.Name - handler := hc.Handler - log.Debugf("Registering hostcall: %s", name) - if _, ok := h.HCMap[name]; ok { - return fmt.Errorf("hostcall already registered: %s", name) - } - h.HCMap[name] = handler - return nil -} - -func (h *HostCalls) Run() { - for { - entry := h.CommMgr.GetIncomingRequest() - req := entry.Req - inv := entry.InvInfo - if handler, ok := h.HCMap[*req.Method]; ok { - result, err := handler(inv, req.Params) - if err != nil { - log.Errorf("Error executing hostcall: %v", err) - // send error response - resp := req.CreateErrorResponse(1, err.Error(), nil) - if err := h.CommMgr.SendOutgoingJsonResponse(*inv.Task, resp); err != nil { - log.Errorf("Error sending response: %v", err) - } - } else { - // send success response - log.Debugf("Hostcall success: %s", *req.Method) - resp := req.CreateSuccessResponse(result) - if err := h.CommMgr.SendOutgoingJsonResponse(*inv.Task, resp); err != nil { - log.Errorf("Error sending response: %v", err) - } - } - } else { - log.Errorf("Hostcall not found: %s", *req.Method) - // send error response - resp := req.CreateErrorResponse(2, "method not found", nil) - if err := h.CommMgr.SendOutgoingJsonResponse(*inv.Task, resp); err != nil { - log.Errorf("Error sending response: %v", err) - } - } - } -} - -func NewCommunicationManager() *CommunicationManager { - return &CommunicationManager{ - respCh: make(chan *RespChanData, 1024), - reqCh: make(chan *ReqChanData, 1024), - outCh: make(map[task.Task]chan task.Message), - - pendingRequests: make(map[json.Number]*requestCallback), - pendingRequestsMu: sync.RWMutex{}, - } -} - -func (c *CommunicationManager) InstallToTask(t task.Task) error { - if t == nil { - log.Errorf("task is nil") - return fmt.Errorf("task is nil") - } - - // check in and out channel - in, out, err := t.CommChannels() - if err != nil { - log.Errorf("Error getting communication channels: %v", err) - return err - } - - c.outCh[t] = in - - go func() { - inv := InvocationInfo{ - Task: &t, - CommMgr: c, - } - - for msg := range out { - // process message - log.Debugf("Received message length: %d", len(msg)) - - req := &rpc.JsonRPCRequest{} - if err := req.Unmarshal([]byte(msg)); err == nil { - log.Debugf("Hostcall received request: %s", *req.Method) - c.reqCh <- &ReqChanData{ - Req: req, - InvInfo: &inv, - } - } else { - resp := &rpc.JsonRPCResponse{} - if err := resp.Unmarshal([]byte(msg)); err == nil { - go func() { - // check if it is response to a pending request - c.pendingRequestsMu.RLock() - entry, ok := c.pendingRequests[*resp.ID] - c.pendingRequestsMu.RUnlock() - if ok { - cb := entry.cb - if err := cb(resp); err 
!= nil { - log.Errorf("Error handling response: %v", err) - } - if entry.autoClear { - c.pendingRequestsMu.Lock() - delete(c.pendingRequests, *resp.ID) - c.pendingRequestsMu.Unlock() - } - return - } - - // this is when we receive a response that is not a pending request - c.respCh <- &RespChanData{ - Resp: resp, - InvInfo: &inv, - } - }() - } else { - log.Errorf("Invalid request: %v. Len %d, Data: %s", err, len(msg), string(msg)) - continue - } - } - - } - }() - - return nil -} - -func (c *CommunicationManager) GetIncomingRequest() *ReqChanData { - return <-c.reqCh -} - -func (c *CommunicationManager) GetIncomingResponse() *RespChanData { - return <-c.respCh -} - -func (c *CommunicationManager) SendOutgoingJsonResponse(t task.Task, resp *rpc.JsonRPCResponse) error { - if data, err := resp.Marshal(); err == nil { - // log.Debugf("Sending data: %v", string(data)) - c.outCh[t] <- data - return nil - } else { - return fmt.Errorf("error marshalling response. err: %v, resp: %+v", err, resp) - } -} - -type ResquestCallback func(resp *rpc.JsonRPCResponse) error - -type requestCallback struct { - cb ResquestCallback - autoClear bool - ts time.Time -} - -// automatically generate the id -func (c *CommunicationManager) SendOutgoingRPCRequestCallback(t task.Task, method string, params interface{}, cb ResquestCallback) error { - req := rpc.NewJsonRPCRequest(method, params) - tmpId := json.Number(fmt.Sprintf("%d", t.NextRequestID())) - req.ID = &tmpId - return c.SendOutgoingJsonRequestCallback(t, req, cb) -} - -// users need to specify the id in the request -func (c *CommunicationManager) SendOutgoingJsonRequestCallback(t task.Task, req *rpc.JsonRPCRequest, cb func(*rpc.JsonRPCResponse) error) error { - if data, err := req.Marshal(); err == nil { - // log.Debugf("Sending data: %v", string(data)) - c.outCh[t] <- data - - // add to pending requests - c.pendingRequestsMu.Lock() - c.pendingRequests[*req.ID] = &requestCallback{ - cb: cb, - autoClear: true, - ts: time.Now(), - } - c.pendingRequestsMu.Unlock() - return nil - } - return fmt.Errorf("error marshalling request") -} - -// automatically generate the id -func (c *CommunicationManager) SendOutgoingRPCRequest(t task.Task, method string, params interface{}) (*rpc.JsonRPCResponse, error) { - req := rpc.NewJsonRPCRequest(method, params) - tmpId := json.Number(fmt.Sprintf("%d", t.NextRequestID())) - req.ID = &tmpId - return c.SendOutgoingJsonRequest(t, req) -} - -// users need to specify the id in the request -func (c *CommunicationManager) SendOutgoingJsonRequest(t task.Task, req *rpc.JsonRPCRequest) (*rpc.JsonRPCResponse, error) { - ch := make(chan *rpc.JsonRPCResponse, 1) - errCh := make(chan error, 1) - if err := c.SendOutgoingJsonRequestCallback(t, req, func(resp *rpc.JsonRPCResponse) error { - log.Debugf("SendOutgoingJsonRequestCallback received response: %s", *req.ID) - if resp.Error != nil { - errCh <- fmt.Errorf("error response: %v", resp.Error) - } else { - ch <- resp - } - return nil - }); err != nil { - return nil, err - } - - select { - case resp := <-ch: - return resp, nil - case err := <-errCh: - return nil, err - case <-time.After(rpc.ResponseTimeout): - return nil, fmt.Errorf("timeout") - } -} diff --git a/worker/hostcalls/common/tools.go b/worker/hostcalls/common/tools.go deleted file mode 100644 index e7352e6..0000000 --- a/worker/hostcalls/common/tools.go +++ /dev/null @@ -1,43 +0,0 @@ -package common - -import ( - "github.com/lfedgeai/spear/worker/task" -) - -type ToolId string -type ToolsetId string -type BuiltInToolCbFunc func(inv 
*InvocationInfo, args interface{}) (interface{}, error) - -type ToolParam struct { - Ptype string - Description string - Required bool -} - -type ToolRegistry struct { - Name string - Description string - Params map[string]ToolParam - Cb string - CbBuiltIn BuiltInToolCbFunc -} - -type ToolsetRegistry struct { - Description string - ToolsIds []ToolId -} - -var ( - GlobalTaskTools = map[task.TaskID]map[ToolId]ToolRegistry{} - GlobalTaskToolsets = map[task.TaskID]map[ToolsetId]ToolsetRegistry{} -) - -var BuiltinTools = []ToolRegistry{} - -func GetBuiltinTools() []ToolRegistry { - return BuiltinTools -} - -func RegisterBuiltinTool(tool ToolRegistry) { - BuiltinTools = append(BuiltinTools, tool) -} diff --git a/worker/hostcalls/embeddings.go b/worker/hostcalls/embeddings.go deleted file mode 100644 index 3992b9a..0000000 --- a/worker/hostcalls/embeddings.go +++ /dev/null @@ -1,39 +0,0 @@ -package hostcalls - -import ( - "encoding/json" - "fmt" - - "github.com/lfedgeai/spear/pkg/rpc/payload/transform" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" - "github.com/lfedgeai/spear/worker/hostcalls/huggingface" - openaihc "github.com/lfedgeai/spear/worker/hostcalls/openai" -) - -type EmbeddingFunc func(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) - -var ( - globalEmbeddings = map[string]EmbeddingFunc{ - "text-embedding-ada-002": openaihc.Embeddings, - "bge-large-en-v1.5": huggingface.Embeddings, - } -) - -func Embeddings(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - embeddingsReq := transform.EmbeddingsRequest{} - err = embeddingsReq.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - for k, v := range globalEmbeddings { - if k == embeddingsReq.Model { - return v(inv, args) - } - } - return nil, fmt.Errorf("embedding not found") -} diff --git a/worker/hostcalls/gen_image.go b/worker/hostcalls/gen_image.go deleted file mode 100644 index 37b2b5c..0000000 --- a/worker/hostcalls/gen_image.go +++ /dev/null @@ -1,52 +0,0 @@ -package hostcalls - -import ( - "encoding/json" - "fmt" - - "github.com/lfedgeai/spear/pkg/rpc/payload/transform" - "github.com/lfedgeai/spear/worker/hostcalls/common" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" - oai "github.com/lfedgeai/spear/worker/hostcalls/openai" -) - -func TextToImage(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - // right now we just call openai TextToSpeech - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - - req := &transform.ImageGenerationRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - req2 := &oai.OpenAIImageGenerationRequest{ - Model: req.Model, - Prompt: req.Prompt, - ResponseFormat: req.ResponseFormat, - } - ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeImageGeneration, req2.Model) - if len(ep) == 0 { - return nil, fmt.Errorf("error getting endpoint for model %s", req2.Model) - } - res, err := oai.OpenAIImageGeneration(ep[0], req2) - if err != nil { - return nil, fmt.Errorf("error calling openai TextToImage: %v", err) - } - - res2 := &transform.ImageGenerationResponse{ - Created: res.Created, - } - for _, obj := range res.Data { - res2.Data = append(res2.Data, transform.ImageObject{ - Url: obj.Url, - 
B64Json: obj.B64Json, - RevisedPrompt: obj.RevisedPrompt, - }) - } - - return res2, nil -} diff --git a/worker/hostcalls/hc_entries.go b/worker/hostcalls/hc_entries.go deleted file mode 100644 index b9696fd..0000000 --- a/worker/hostcalls/hc_entries.go +++ /dev/null @@ -1,98 +0,0 @@ -package hostcalls - -import ( - "github.com/lfedgeai/spear/pkg/rpc/payload" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" -) - -var Hostcalls = []*hostcalls.HostCall{ - { - Name: payload.HostCallTransform, - Handler: Transform, - }, - { - Name: payload.HostCallTransformConfig, - Handler: TransformConfig, - }, - { - Name: payload.HostCallToolNew, - Handler: NewTool, - }, - { - Name: payload.HostCallToolsetNew, - Handler: NewToolset, - }, - { - Name: payload.HostCallToolsetInstallBuiltins, - Handler: ToolsetInstallBuiltins, - }, - // // chat operations - // { - // Name: transform.HostCallChatCompletion, - // Handler: ChatCompletionWithTools, - // }, - // // text to speech operations - // { - // Name: openai.HostCallTextToSpeech, - // Handler: openaihc.TextToSpeech, - // }, - // // image generation operations - // { - // Name: openai.HostCallImageGeneration, - // Handler: openaihc.ImageGeneration, - // }, - // // embeddings operations - // { - // Name: openai.HostCallEmbeddings, - // Handler: openaihc.Embeddings, - // }, - // vector store operations - { - Name: payload.HostCallVectorStoreCreate, - Handler: VectorStoreCreate, - }, - { - Name: payload.HostCallVectorStoreDelete, - Handler: VectorStoreDelete, - }, - { - Name: payload.HostCallVectorStoreInsert, - Handler: VectorStoreInsert, - }, - { - Name: payload.HostCallVectorStoreSearch, - Handler: VectorStoreSearch, - }, - // message passing operations - { - Name: payload.HostCallMessagePassingRegister, - Handler: MessagePassingRegister, - }, - { - Name: payload.HostCallMessagePassingUnregister, - Handler: MessagePassingUnregister, - }, - { - Name: payload.HostCallMessagePassingLookup, - Handler: MessagePassingLookup, - }, - { - Name: payload.HostCallMessagePassingSend, - Handler: MessagePassingSend, - }, - // input operations - { - Name: payload.HostCallInput, - Handler: Input, - }, - // speak operations - { - Name: payload.HostCallSpeak, - Handler: Speak, - }, - // record operations - { - Name: payload.HostCallRecord, - Handler: Record, - }, -} diff --git a/worker/hostcalls/huggingface/huggingface_hc.go b/worker/hostcalls/huggingface/huggingface_hc.go deleted file mode 100644 index 7aea126..0000000 --- a/worker/hostcalls/huggingface/huggingface_hc.go +++ /dev/null @@ -1,88 +0,0 @@ -package huggingface - -import ( - "bytes" - "encoding/json" - "fmt" - "os" - - "github.com/lfedgeai/spear/pkg/net" - "github.com/lfedgeai/spear/pkg/rpc/payload/transform" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" - log "github.com/sirupsen/logrus" -) - -type HuggingFaceEmbeddingsRequest struct { - Inputs string `json:"inputs"` -} - -func Embeddings(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - // verify the type of args is EmbeddingsRequest - // use json marshal and unmarshal to verify the type - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - embeddingsReq := transform.EmbeddingsRequest{} - err = embeddingsReq.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - embeddingsReq2 := HuggingFaceEmbeddingsRequest{ - Inputs: embeddingsReq.Input, - } - - jsonBytes, err = 
json.Marshal(embeddingsReq2) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - - // make sure HUGGINGFACEHUB_API_TOKEN is there - if os.Getenv("HUGGINGFACEHUB_API_TOKEN") == "" { - return nil, fmt.Errorf("error getting huggingface api token") - } - apiKey := os.Getenv("HUGGINGFACEHUB_API_TOKEN") - - log.Debugf("Embeddings Request: %s", string(jsonBytes)) - res, err := net.SendRequest( - "https://api-inference.huggingface.co/models/BAAI/bge-large-en-v1.5", - bytes.NewBuffer(jsonBytes), - net.ContentTypeJSON, - apiKey, - ) - - if err != nil { - return nil, fmt.Errorf("error sending request: %v", err) - } - - listRes := []float64{} - if err := json.Unmarshal(res, &listRes); err != nil { - // might be something like - // {"error":"Model BAAI/bge-large-en-v1.5 is currently loading","estimated_time":53.62286376953125} - tmp := map[string]interface{}{} - if err := json.Unmarshal(res, &tmp); err != nil { - log.Errorf("Error unmarshalling data: %v", res) - return nil, fmt.Errorf("error unmarshalling data. %v", err) - } - if _, ok := tmp["error"]; ok { - log.Warnf("Model is not ready yet: %v", tmp) - listRes = []float64{1.1, 2.2, 3.3} - } else { - log.Errorf("Error unmarshalling data: %v", res) - return nil, fmt.Errorf("error unmarshalling data. %v", err) - } - } - respData := transform.EmbeddingsResponse{} - respData.Data = []transform.EmbeddingObject{ - { - Object: "embedding", - Embedding: listRes, - Index: 0, - }, - } - respData.Model = "bge-large-en-v1.5" - - // return the response - return respData, nil -} diff --git a/worker/hostcalls/msgpassing.go b/worker/hostcalls/msgpassing.go deleted file mode 100644 index 92a4078..0000000 --- a/worker/hostcalls/msgpassing.go +++ /dev/null @@ -1,140 +0,0 @@ -package hostcalls - -import ( - "encoding/json" - "fmt" - - "math/rand" - - "github.com/lfedgeai/spear/pkg/rpc/payload" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" - log "github.com/sirupsen/logrus" -) - -type MessagePassingRegistry struct { - name string - method string - pendingData chan interface{} - id uint64 -} - -var ( - globalRegisteredMessagePassing = map[string]MessagePassingRegistry{} -) - -func MessagePassingRegister(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("Executing hostcall \"%s\" with args %v for task %s", - payload.HostCallMessagePassingRegister, args, task.ID()) - - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - req := payload.MessagePassingRegisterRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - globalRegisteredMessagePassing[req.Name] = MessagePassingRegistry{ - name: req.Name, - method: req.Method, - pendingData: make(chan interface{}), - id: rand.Uint64(), - } - - return &payload.MessagePassingRegisterResponse{ - MsgPassingId: globalRegisteredMessagePassing[req.Name].id, - }, nil -} - -func MessagePassingUnregister(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("Executing hostcall \"%s\" with args %v for task %s", - payload.HostCallMessagePassingUnregister, args, task.ID()) - - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - req := payload.MessagePassingUnregisterRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", 
err) - } - - found := false - - for k, v := range globalRegisteredMessagePassing { - if v.id == req.MsgPassingId { - delete(globalRegisteredMessagePassing, k) - found = true - break - } - } - - if !found { - return nil, fmt.Errorf("message passing id not found") - } - - return &payload.MessagePassingUnregisterResponse{ - MsgPassingId: req.MsgPassingId, - }, nil -} - -func MessagePassingLookup(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("Executing hostcall \"%s\" with args %v for task %s", - payload.HostCallMessagePassingLookup, args, task.ID()) - - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - req := payload.MessagePassingLookupRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - if v, ok := globalRegisteredMessagePassing[req.Name]; ok { - return &payload.MessagePassingLookupResponse{ - MsgPassingId: v.id, - }, nil - } - - return nil, fmt.Errorf("message passing name not found") -} - -func MessagePassingSend(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("Executing hostcall \"%s\" with args %v for task %s", - payload.HostCallMessagePassingSend, args, task.ID()) - - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - req := payload.MessagePassingSendRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - for _, v := range globalRegisteredMessagePassing { - if v.id == req.MsgPassingId { - // send without blocking - select { - case v.pendingData <- req.Data: - return &payload.MessagePassingSendResponse{ - MsgPassingId: req.MsgPassingId, - }, nil - default: - return nil, fmt.Errorf("message passing channel full") - } - } - } - - return nil, fmt.Errorf("message passing id not found") -} diff --git a/worker/hostcalls/stt.go b/worker/hostcalls/stt.go deleted file mode 100644 index a3216ae..0000000 --- a/worker/hostcalls/stt.go +++ /dev/null @@ -1,40 +0,0 @@ -package hostcalls - -import ( - "encoding/json" - "fmt" - - "github.com/lfedgeai/spear/pkg/rpc/payload/transform" - "github.com/lfedgeai/spear/worker/hostcalls/common" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" - oai "github.com/lfedgeai/spear/worker/hostcalls/openai" -) - -func SpeechToText(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - // right now we just call openai SpeechToText - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - - req := &transform.SpeechToTextRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - req2 := &oai.OpenAISpeechToTextRequest{ - Model: req.Model, - Audio: req.Audio, - } - ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeSpeechToText, req2.Model) - if len(ep) == 0 { - return nil, fmt.Errorf("error getting endpoint for model %s", req2.Model) - } - res, err := oai.OpenAISpeechToText(ep[0], req2) - if err != nil { - return nil, fmt.Errorf("error calling openai SpeechToText: %v", err) - } - - return res, nil -} diff --git a/worker/hostcalls/tools.go b/worker/hostcalls/tools.go deleted file mode 100644 index e377442..0000000 --- a/worker/hostcalls/tools.go +++ /dev/null @@ -1,186 +0,0 @@ -package hostcalls 
- -import ( - "encoding/json" - "fmt" - - "github.com/google/uuid" - - "github.com/lfedgeai/spear/pkg/rpc/payload" - hcommon "github.com/lfedgeai/spear/worker/hostcalls/common" - "github.com/lfedgeai/spear/worker/task" - log "github.com/sirupsen/logrus" -) - -func NewTool(inv *hcommon.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("NewTool called from task [%s]", task.Name()) - taskId := task.ID() - - // args is a NewToolRequest - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, err - } - - req := &payload.NewToolRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, err - } - - // check if task exists - if _, ok := hcommon.GlobalTaskTools[taskId]; !ok { - hcommon.GlobalTaskTools[taskId] = make(map[hcommon.ToolId]hcommon.ToolRegistry) - } - - tid := hcommon.ToolId(uuid.New().String()) - // create tool - hcommon.GlobalTaskTools[taskId][tid] = hcommon.ToolRegistry{ - Name: req.Name, - Description: req.Description, - Params: make(map[string]hcommon.ToolParam), - Cb: req.Cb, - CbBuiltIn: nil, - } - - for _, param := range req.Params { - hcommon.GlobalTaskTools[taskId][tid].Params[param.Name] = hcommon.ToolParam{ - Ptype: param.Type, - Description: param.Description, - Required: param.Required, - } - } - - return &payload.NewToolResponse{ - Tid: string(tid), - }, nil -} - -func NewToolset(inv *hcommon.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("NewToolset called from task [%s]", task.Name()) - - // args is a NewToolsetRequest - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, err - } - - req := &payload.NewToolsetRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, err - } - - // check if task exists - taskId := task.ID() - if _, ok := hcommon.GlobalTaskToolsets[taskId]; !ok { - hcommon.GlobalTaskToolsets[taskId] = make(map[hcommon.ToolsetId]hcommon.ToolsetRegistry) - } - - tids := []hcommon.ToolId{} - for _, tid := range req.ToolIds { - // make sure tool exists - if _, ok := hcommon.GlobalTaskTools[taskId][hcommon.ToolId(tid)]; !ok { - return nil, fmt.Errorf("tool with id %s does not exist", tid) - } - tids = append(tids, hcommon.ToolId(tid)) - } - - tsid := hcommon.ToolsetId(uuid.New().String()) - // create toolset - hcommon.GlobalTaskToolsets[taskId][tsid] = hcommon.ToolsetRegistry{ - Description: req.Description, - ToolsIds: tids, - } - - return &payload.NewToolsetResponse{ - Tsid: string(tsid), - }, nil -} - -func ToolsetInstallBuiltins(inv *hcommon.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("ToolsetInstallBuiltins called from task [%s]", task.Name()) - - // args is a ToolsetInstallBuiltinsRequest - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, err - } - - req := &payload.ToolsetInstallBuiltinsRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, err - } - - // check if task exists - taskId := task.ID() - if _, ok := hcommon.GlobalTaskToolsets[taskId]; !ok { - hcommon.GlobalTaskToolsets[taskId] = make(map[hcommon.ToolsetId]hcommon.ToolsetRegistry) - } - - // install BuiltinTools to task - tids := []hcommon.ToolId{} - for _, tool := range hcommon.GetBuiltinTools() { - tid := hcommon.ToolId(uuid.New().String()) - if _, ok := hcommon.GlobalTaskTools[taskId]; !ok { - hcommon.GlobalTaskTools[taskId] = make(map[hcommon.ToolId]hcommon.ToolRegistry) - } - hcommon.GlobalTaskTools[taskId][tid] = tool - tids = append(tids, tid) - } - - 
tsid := req.Tsid - // add BuiltinTools to toolset - if toolset, ok := hcommon.GlobalTaskToolsets[taskId][hcommon.ToolsetId(tsid)]; !ok { - return nil, fmt.Errorf("toolset with id %s does not exist", tsid) - } else { - toolset.ToolsIds = append(toolset.ToolsIds, tids...) - hcommon.GlobalTaskToolsets[taskId][hcommon.ToolsetId(tsid)] = toolset - } - - return &payload.ToolsetInstallBuiltinsResponse{ - Tsid: tsid, - }, nil -} - -func GetToolset(task task.Task, tsid hcommon.ToolsetId) (*hcommon.ToolsetRegistry, bool) { - taskId := task.ID() - if _, ok := hcommon.GlobalTaskToolsets[taskId]; !ok { - return nil, false - } - if _, ok := hcommon.GlobalTaskToolsets[taskId][tsid]; !ok { - return nil, false - } - res := hcommon.GlobalTaskToolsets[taskId][tsid] - return &res, true -} - -func GetToolById(task task.Task, tid hcommon.ToolId) (*hcommon.ToolRegistry, bool) { - taskId := task.ID() - if _, ok := hcommon.GlobalTaskTools[taskId]; !ok { - return nil, false - } - if _, ok := hcommon.GlobalTaskTools[taskId][tid]; !ok { - return nil, false - } - res := hcommon.GlobalTaskTools[taskId][tid] - return &res, true -} - -func GetToolByName(task task.Task, name string) (*hcommon.ToolRegistry, bool) { - taskId := task.ID() - if _, ok := hcommon.GlobalTaskTools[taskId]; !ok { - return nil, false - } - for tid, tool := range hcommon.GlobalTaskTools[taskId] { - if tool.Name == name { - toolreg := hcommon.GlobalTaskTools[taskId][tid] - return &toolreg, true - } - } - return nil, false -} diff --git a/worker/hostcalls/transform.go b/worker/hostcalls/transform.go deleted file mode 100644 index 5fdeea1..0000000 --- a/worker/hostcalls/transform.go +++ /dev/null @@ -1,189 +0,0 @@ -package hostcalls - -import ( - "encoding/json" - "fmt" - - "github.com/lfedgeai/spear/pkg/rpc/payload" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" - t "github.com/lfedgeai/spear/worker/task" - log "github.com/sirupsen/logrus" -) - -type TransformRegistry struct { - name string - inputTypes []payload.TransformType - outputTypes []payload.TransformType - operations []payload.TransformOperation - cb func(*hostcalls.InvocationInfo, interface{}) (interface{}, error) -} - -var ( - globalRegisteredTransform = []TransformRegistry{ - { - name: "chat_with_tools", - inputTypes: []payload.TransformType{payload.TransformTypeText}, - outputTypes: []payload.TransformType{payload.TransformTypeText}, - operations: []payload.TransformOperation{ - payload.TransformOperationLLM, - payload.TransformOperationTools, - }, - cb: ChatCompletionWithTools, - }, - { - name: "chat", - inputTypes: []payload.TransformType{payload.TransformTypeText}, - outputTypes: []payload.TransformType{payload.TransformTypeText}, - operations: []payload.TransformOperation{payload.TransformOperationLLM}, - cb: ChatCompletionNoTools, - }, - { - name: "embeddings", - inputTypes: []payload.TransformType{payload.TransformTypeText}, - outputTypes: []payload.TransformType{payload.TransformTypeVector}, - operations: []payload.TransformOperation{payload.TransformOperationEmbeddings}, - cb: Embeddings, - }, - { - name: "text-to-speech", - inputTypes: []payload.TransformType{payload.TransformTypeText}, - outputTypes: []payload.TransformType{payload.TransformTypeAudio}, - operations: []payload.TransformOperation{payload.TransformOperationTextToSpeech}, - cb: TextToSpeech, - }, - { - name: "speech-to-text", - inputTypes: []payload.TransformType{payload.TransformTypeAudio}, - outputTypes: []payload.TransformType{payload.TransformTypeText}, - operations: 
[]payload.TransformOperation{payload.TransformOperationSpeechToText}, - cb: SpeechToText, - }, - { - name: "text-to-image", - inputTypes: []payload.TransformType{payload.TransformTypeText}, - outputTypes: []payload.TransformType{payload.TransformTypeImage}, - operations: []payload.TransformOperation{payload.TransformOperationTextToImage}, - cb: TextToImage, - }, - } -) - -func isSubSetTransform(a, b []payload.TransformType) bool { - for _, t1 := range a { - found := false - for _, t2 := range b { - if t1 == t2 { - found = true - break - } - } - if !found { - return false - } - } - return true -} - -func isSubsetOperation(a, b []payload.TransformOperation) bool { - for _, t1 := range a { - found := false - for _, t2 := range b { - if t1 == t2 { - found = true - break - } - } - if !found { - return false - } - } - return true -} - -func TransformConfig(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("Executing hostcall \"%s\" with args %v for task %s", - payload.HostCallTransformConfig, args, task.ID()) - // convert args to TransformConfigRequest - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - - req := &payload.TransformConfigRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - if req.Reset { - task.SetVar(t.TVTest, nil) - return &payload.TransformConfigResponse{ - Result: "success", - }, nil - } - - if req.Test != "" { - task.SetVar(t.TVTest, req.Test) - } - - return &payload.TransformConfigResponse{ - Result: "success", - }, nil -} - -func Transform(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("Executing hostcall \"%s\" with args %v for task %s", - payload.HostCallTransform, args, task.ID()) - // convert args to TransformRequest - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - - req := &payload.TransformRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - var candid *TransformRegistry - - // find the transform registry - for _, reg := range globalRegisteredTransform { - if isSubSetTransform(req.InputTypes, reg.inputTypes) && - isSubSetTransform(req.OutputTypes, reg.outputTypes) && - isSubsetOperation(req.Operations, reg.operations) { - if candid != nil { - if len(reg.inputTypes) <= len(candid.inputTypes) && - len(reg.outputTypes) <= len(candid.outputTypes) && - len(reg.operations) <= len(candid.operations) { - candid = &reg - } - } else { - candid = &reg - } - } - } - - if candid != nil { - log.Infof("Using transform registry %s", candid.name) - res, err := candid.cb(inv, req.Params) - if err != nil { - return nil, fmt.Errorf("error calling %s: %v", candid.name, err) - } - - transResp := &payload.TransformResponse{ - Results: []payload.TransformResult{ - { - Type: candid.outputTypes[0], - Data: res, - }, - }, - } - return transResp, nil - } - - return nil, fmt.Errorf("hostcall \"%s\" not implemented", payload.HostCallTransform) -} diff --git a/worker/hostcalls/tts.go b/worker/hostcalls/tts.go deleted file mode 100644 index 9af6338..0000000 --- a/worker/hostcalls/tts.go +++ /dev/null @@ -1,37 +0,0 @@ -package hostcalls - -import ( - "fmt" - - "github.com/lfedgeai/spear/pkg/rpc/payload/transform" - "github.com/lfedgeai/spear/pkg/utils" -
"github.com/lfedgeai/spear/worker/hostcalls/common" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" - oai "github.com/lfedgeai/spear/worker/hostcalls/openai" -) - -func TextToSpeech(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - // right now we just call openai TextToSpeech - req := &transform.TextToSpeechRequest{} - err := utils.InterfaceToType(&req, args) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - req2 := &oai.OpenAITextToSpeechRequest{ - Model: req.Model, - Input: req.Input, - Voice: req.Voice, - Format: req.Format, - } - ep := common.GetAPIEndpointInfo(common.OpenAIFunctionTypeTextToSpeech, req2.Model) - if len(ep) == 0 { - return nil, fmt.Errorf("error getting endpoint for model %s", req2.Model) - } - res, err := oai.OpenAITextToSpeech(ep[0], req2) - if err != nil { - return nil, fmt.Errorf("error calling openai TextToSpeech: %v", err) - } - - return res, nil -} diff --git a/worker/hostcalls/utils.go b/worker/hostcalls/utils.go deleted file mode 100644 index 2aea9e2..0000000 --- a/worker/hostcalls/utils.go +++ /dev/null @@ -1,37 +0,0 @@ -package hostcalls - -import ( - "bytes" - "fmt" - "io" - "net/http" - "os" -) - -func sendBufferData(data *bytes.Buffer, url string) ([]byte, error) { - // create a https request to url and use data as the request body - req, err := http.NewRequest("POST", url, data) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - // get api key from environment variable - apiKey := os.Getenv("OPENAI_API_KEY") - // set the headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) - // send the request - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("error sending request: %v", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response: %v", err) - } - - return body, nil -} diff --git a/worker/hostcalls/vectorstore.go b/worker/hostcalls/vectorstore.go deleted file mode 100644 index d07b8a6..0000000 --- a/worker/hostcalls/vectorstore.go +++ /dev/null @@ -1,293 +0,0 @@ -package hostcalls - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/lfedgeai/spear/pkg/rpc/payload" - hostcalls "github.com/lfedgeai/spear/worker/hostcalls/common" - "github.com/lfedgeai/spear/worker/task" - "github.com/qdrant/go-client/qdrant" - log "github.com/sirupsen/logrus" -) - -var ( - globalVectorStoreRegistries = make(map[task.TaskID]*VectorStoreRegistry) -) - -type VectorStore struct { - Name string - NextID uint64 -} - -type VectorStoreRegistry struct { - Stores []*VectorStore - Client *qdrant.Client -} - -type VectorStoreSearchResult struct { - Vector []float32 - Data []byte -} - -func NewVectorStoreRegistry() (*VectorStoreRegistry, error) { - qdrantClient, err := qdrant.NewClient(&qdrant.Config{ - Host: "localhost", - Port: 6334, - }) - if err != nil { - log.Errorf("Error creating qdrant client: %v", err) - return nil, err - } - // list all collections - collections, err := qdrantClient.ListCollections(context.Background()) - if err != nil { - log.Errorf("Error listing collections: %v", err) - return nil, err - } - log.Infof("Collections: %v", collections) - return &VectorStoreRegistry{ - Stores: make([]*VectorStore, 0), - Client: qdrantClient, - }, nil -} - -func (r *VectorStoreRegistry) Create(storeName string, dimensions uint64) (int, error) { - 
log.Infof("Creating vector store with name %s", storeName) - // duplicated store is not allowed - for i, store := range r.Stores { - if store.Name == storeName { - return i, fmt.Errorf("store with name %s already exists", storeName) - } - } - - // create the vector store in qdrant - err := r.Client.CreateCollection(context.Background(), &qdrant.CreateCollection{ - CollectionName: storeName, - VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ - Size: dimensions, - Distance: qdrant.Distance_Cosine, - }), - }) - if err != nil { - return -1, fmt.Errorf("error creating collection: %v", err) - } - - // create a new vector store with the given name - r.Stores = append(r.Stores, &VectorStore{ - Name: storeName, - NextID: 1, - }) - - return len(r.Stores) - 1, nil -} - -func (r *VectorStoreRegistry) Delete(vid int) error { - log.Infof("Deleting vector store with id %d", vid) - // delete the vector store in qdrant - err := r.Client.DeleteCollection(context.Background(), r.Stores[vid].Name) - if err != nil { - return fmt.Errorf("error deleting collection: %v", err) - } - - // remove the vid-th vector store - r.Stores = append(r.Stores[:vid], r.Stores[vid+1:]...) - - return nil -} - -func (r *VectorStoreRegistry) Insert(vid int, vector []float32, payload []byte) error { - log.Infof("Inserting vector into vector store with id %d", vid) - // insert the vector into qdrant - opInfo, err := r.Client.Upsert(context.Background(), &qdrant.UpsertPoints{ - CollectionName: r.Stores[vid].Name, - Points: []*qdrant.PointStruct{ - { - Id: qdrant.NewIDNum(r.Stores[vid].NextID), - Payload: qdrant.NewValueMap(map[string]interface{}{ - "payload": payload, - }), - Vectors: qdrant.NewVectors(vector...), - }, - }, - }) - if err != nil { - return fmt.Errorf("error upserting points: %v", err) - } - r.Stores[vid].NextID = r.Stores[vid].NextID + 1 - log.Infof("Upsert operation info: %v", opInfo) - return nil -} - -func (r *VectorStoreRegistry) Search(vid int, vector []float32, limit uint64) ([]*VectorStoreSearchResult, error) { - log.Infof("Searching vector in vector store with vid %d and vector %v", vid, vector) - // search the vector in qdrant - result, err := r.Client.Query(context.Background(), &qdrant.QueryPoints{ - CollectionName: r.Stores[vid].Name, - Query: qdrant.NewQuery(vector...), - Limit: &limit, - }) - if err != nil { - return nil, fmt.Errorf("error querying points: %v", err) - } - ret := make([]*VectorStoreSearchResult, len(result)) - for i, res := range result { - if res.Vectors == nil { - log.Infof(fmt.Sprintf("Vector is nil: %v", res)) - ret[i] = &VectorStoreSearchResult{ - Vector: nil, - Data: []byte(res.Payload["payload"].String()), - } - } else { - ret[i] = &VectorStoreSearchResult{ - Vector: res.Vectors.GetVector().Data, - Data: []byte(res.Payload["payload"].String()), - } - } - } - log.Infof("Search result: %+v", ret) - return ret, nil -} - -func VectorStoreCreate(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("Executing hostcall \"%s\" with args %v", payload.HostCallVectorStoreCreate, args) - // verify the type of args is string - // use json marshal and unmarshal to verify the type - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - req := payload.VectorStoreCreateRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - log.Infof("VectorStoreCreate Request: %v", req) - // create a new vector 
store - if _, ok := globalVectorStoreRegistries[task.ID()]; !ok { - val, err := NewVectorStoreRegistry() - if err != nil { - return nil, fmt.Errorf("error creating vector store registry: %v", err) - } - globalVectorStoreRegistries[task.ID()] = val - } - - vid, err := globalVectorStoreRegistries[task.ID()].Create(req.Name, req.Dimentions) - if err != nil { - return nil, fmt.Errorf("error creating vector store: %v", err) - } - - // return the response - return &payload.VectorStoreCreateResponse{ - VID: vid, - }, nil -} - -func VectorStoreDelete(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("Executing hostcall \"%s\" with args %v", payload.HostCallVectorStoreDelete, args) - // verify the type of args is int - // use json marshal and unmarshal to verify the type - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - req := payload.VectorStoreDeleteRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - log.Infof("VectorStoreDelete Request: %v", req) - // delete the vector store - if _, ok := globalVectorStoreRegistries[task.ID()]; !ok { - return nil, fmt.Errorf("vector store registry not found") - } - - err = globalVectorStoreRegistries[task.ID()].Delete(req.VID) - if err != nil { - return nil, fmt.Errorf("error deleting vector store: %v", err) - } - - // return the response - return &payload.VectorStoreDeleteResponse{ - VID: req.VID, - }, nil -} - -func VectorStoreInsert(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("Executing hostcall \"%s\" with args %v", payload.HostCallVectorStoreInsert, args) - // verify the type of args is VectorStoreInsertRequest - // use json marshal and unmarshal to verify the type - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - req := payload.VectorStoreInsertRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - log.Infof("VectorStoreInsert Request: %s", string(jsonBytes)) - // insert the vector into the vector store - v, ok := globalVectorStoreRegistries[task.ID()] - if !ok { - return nil, fmt.Errorf("vector store registry not found") - } - - err = v.Insert(req.VID, req.Vector, req.Data) - if err != nil { - return nil, fmt.Errorf("error inserting vector: %v", err) - } - - // return the response - return payload.VectorStoreInsertResponse{ - VID: req.VID, - }, nil -} - -func VectorStoreSearch(inv *hostcalls.InvocationInfo, args interface{}) (interface{}, error) { - task := *(inv.Task) - log.Debugf("Executing hostcall \"%s\" with args %v", payload.HostCallVectorStoreSearch, args) - // verify the type of args is VectorStoreSearchRequest - // use json marshal and unmarshal to verify the type - jsonBytes, err := json.Marshal(args) - if err != nil { - return nil, fmt.Errorf("error marshalling args: %v", err) - } - req := payload.VectorStoreSearchRequest{} - err = req.Unmarshal(jsonBytes) - if err != nil { - return nil, fmt.Errorf("error unmarshalling args: %v", err) - } - - log.Infof("VectorStoreSearch Request: %s", string(jsonBytes)) - // search the vector in the vector store - v, ok := globalVectorStoreRegistries[task.ID()] - if !ok { - return nil, fmt.Errorf("vector store registry not found") - } - - result, err := v.Search(req.VID, req.Vector, req.Limit) - if 
err != nil { - return nil, fmt.Errorf("error searching vector: %v", err) - } - - // return the response - res := payload.VectorStoreSearchResponse{ - VID: req.VID, - Entries: make([]payload.VectorStoreSearchResponseEntry, len(result)), - } - for i, r := range result { - res.Entries[i] = payload.VectorStoreSearchResponseEntry{ - Vector: r.Vector, - Data: r.Data, - } - } - return res, nil -} diff --git a/workload/docker/go/dummy/Dockerfile b/workload/docker/go/dummy/Dockerfile deleted file mode 100644 index be720aa..0000000 --- a/workload/docker/go/dummy/Dockerfile +++ /dev/null @@ -1,12 +0,0 @@ -FROM ubuntu:24.04 - -# update the package lists -RUN apt-get update -y -RUN apt-get install -y netcat-openbsd net-tools - -WORKDIR / - -# COPY ./scripts/start.sh /start -COPY ./bin/start /start - -CMD ["sleep", "infinity"] diff --git a/workload/docker/go/dummy/Makefile b/workload/docker/go/dummy/Makefile deleted file mode 100644 index 8c6cc1f..0000000 --- a/workload/docker/go/dummy/Makefile +++ /dev/null @@ -1,14 +0,0 @@ -.PHONY: all build clean - -CURRENT_DIR := $(shell pwd) -OUTPUT_DIR := $(shell pwd)/bin - -all: build - docker compose build - -build: - GOOS=linux go build -o $(OUTPUT_DIR)/start \ - $(CURRENT_DIR)/src/start.go - -clean: - rm -rf $(OUTPUT_DIR) diff --git a/workload/docker/go/dummy/compose.yaml b/workload/docker/go/dummy/compose.yaml deleted file mode 100644 index e3c0aa6..0000000 --- a/workload/docker/go/dummy/compose.yaml +++ /dev/null @@ -1,8 +0,0 @@ -services: - dummy: - image: dummy - build: - context: . - dockerfile: Dockerfile - environment: - - DUMMY_ENV=1 \ No newline at end of file diff --git a/workload/docker/go/dummy/scripts/start.sh b/workload/docker/go/dummy/scripts/start.sh deleted file mode 100755 index 8656ab9..0000000 --- a/workload/docker/go/dummy/scripts/start.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/sh - -# a json rpc example - -for i in `seq 1 2`; -do - echo "{\"jsonrpc\":\"2.0\",\"method\":\"eth_blockNumber\",\"params\":[],\"id\":83}" - sleep 10 -done diff --git a/workload/docker/go/dummy/src/start.go b/workload/docker/go/dummy/src/start.go deleted file mode 100644 index 758348c..0000000 --- a/workload/docker/go/dummy/src/start.go +++ /dev/null @@ -1,158 +0,0 @@ -package main - -import ( - "encoding/binary" - "fmt" - "io" - "net" - "os" - "strconv" - "time" - - // flags support - - "github.com/lfedgeai/spear/pkg/rpc" - "github.com/lfedgeai/spear/pkg/rpc/payload" - - // logrus - log "github.com/sirupsen/logrus" -) - -var hdl *rpc.GuestRPCManager -var hostaddr string -var secret string - -var input io.Reader -var output io.Writer - -// parse arguments -func init() { - // get hostaddr and secret from environment variables - hostaddr = os.Getenv("SERVICE_ADDR") - secret = os.Getenv("SECRET") - - log.Debugf("Connecting to host at %s", hostaddr) - // create tcp connection to host - conn, err := net.Dial("tcp", hostaddr) - if err != nil { - log.Fatalf("failed to connect to host: %v", err) - } - - // sending the secret - // convert secret string to int64 - secretInt, err := strconv.ParseInt(secret, 10, 64) - if err != nil { - log.Fatalf("failed to convert secret to int64: %v", err) - } - // convert int64 to little endian byte array - secretBytes := make([]byte, 8) - binary.LittleEndian.PutUint64(secretBytes, uint64(secretInt)) - // write secret to connection - _, err = conn.Write(secretBytes) - if err != nil { - log.Fatalf("failed to write secret to connection: %v", err) - } - - // create input and output files from connection - input = conn - output = conn -} - -func 
main() { - hdl := rpc.NewGuestRPCManager( - func(req *rpc.JsonRPCRequest) (*rpc.JsonRPCResponse, error) { - log.Debugf("Request: %s", *req.Method) - return rpc.NewJsonRPCResponse(*req.ID, nil), nil - }, - nil, - ) - hdl.SetInput(input) - hdl.SetOutput(output) - - hdl.RegisterIncomingHandler("handle", func(args interface{}) (interface{}, error) { - log.Infof("Incoming request: %v", args) - return "ok", nil - }) - go hdl.Run() - - resp, err := rpc.ChatCompletion(hdl, "gpt-4o", []payload.ChatMessageV2{ - { - Metadata: map[string]interface{}{ - "role": "system", - }, - Content: "Hello, how can I help you?", - }, - { - Metadata: map[string]interface{}{ - "role": "user", - }, - Content: "I need help with my computer", - }, - }, "") - if err != nil { - panic(err) - } - log.Infof("Response: %v", resp) - - _, err = rpc.Embeddings(hdl, "text-embedding-ada-002", //"bge-large-en-v1.5" - "The food was delicious and the waiter...") - if err != nil { - panic(err) - } - - randName := fmt.Sprintf("vdb-%d", time.Now().UnixNano()) - - // vector store ops - req3 := rpc.NewJsonRPCRequest(payload.HostCallVectorStoreCreate, payload.VectorStoreCreateRequest{ - Name: randName, - Dimentions: 4, - }) - if resp, err := hdl.SendJsonRequest(req3); err != nil { - panic(err) - } else { - log.Infof("Response: %v", resp) - } - - data := [][]float32{ - {0.05, 0.61, 0.76, 0.74}, - {0.19, 0.81, 0.75, 0.11}, - {0.36, 0.55, 0.47, 0.94}, - {0.18, 0.01, 0.85, 0.80}, - {0.24, 0.18, 0.22, 0.44}, - {0.35, 0.08, 0.11, 0.44}, - } - - for _, v := range data { - req3_5 := rpc.NewJsonRPCRequest(payload.HostCallVectorStoreInsert, payload.VectorStoreInsertRequest{ - VID: 0, - Vector: v, - Data: []byte("test data"), - }) - if resp, err := hdl.SendJsonRequest(req3_5); err != nil { - panic(err) - } else { - log.Infof("Response: %.1024v", resp) - } - } - - req3_6 := rpc.NewJsonRPCRequest(payload.HostCallVectorStoreSearch, payload.VectorStoreSearchRequest{ - VID: 0, - Vector: []float32{0.2, 0.1, 0.9, 0.7}, - Limit: 1, - }) - if resp, err := hdl.SendJsonRequest(req3_6); err != nil { - panic(err) - } else { - log.Infof("Response: %v", resp) - } - - // delete vector store - req4 := rpc.NewJsonRPCRequest(payload.HostCallVectorStoreDelete, payload.VectorStoreDeleteRequest{ - VID: 0, - }) - if resp, err := hdl.SendJsonRequest(req4); err != nil { - panic(err) - } else { - log.Infof("Response: %v", resp) - } -} diff --git a/workload/docker/go/gen_image/Makefile b/workload/docker/go/gen_image/Makefile index dbe6095..84a4797 100644 --- a/workload/docker/go/gen_image/Makefile +++ b/workload/docker/go/gen_image/Makefile @@ -5,19 +5,15 @@ OUTPUT_DIR := $(shell pwd)/bin PROJ_NAME := $(shell basename $(CURRENT_DIR)) all: build - docker compose build + docker-compose build --no-cache start: GOOS=linux go build -o $(OUTPUT_DIR)/start \ $(CURRENT_DIR)/src/start.go -build: start demo - -demo: - go build -o $(OUTPUT_DIR)/demo \ - $(CURRENT_DIR)/src/demo.go +build: start clean: rm -rf $(OUTPUT_DIR) -.PHONY: all build clean demo +.PHONY: all build clean diff --git a/workload/docker/go/gen_image/compose.yaml b/workload/docker/go/gen_image/compose.yaml index 170b8b8..e203597 100644 --- a/workload/docker/go/gen_image/compose.yaml +++ b/workload/docker/go/gen_image/compose.yaml @@ -1,6 +1,6 @@ services: - dummy: - image: gen_image + gen_image: + image: gen_image:latest build: context: . 
dockerfile: Dockerfile \ No newline at end of file diff --git a/workload/docker/go/gen_image/src/demo.go b/workload/docker/go/gen_image/src/demo.go deleted file mode 100644 index febd6e4..0000000 --- a/workload/docker/go/gen_image/src/demo.go +++ /dev/null @@ -1,122 +0,0 @@ -package main - -import ( - "bufio" - "bytes" - "encoding/base64" - "fmt" - "net/http" - "os" - "os/exec" - "runtime" - - "github.com/lfedgeai/spear/pkg/rpc/payload/transform" - "github.com/lfedgeai/spear/pkg/tools/docker" - log "github.com/sirupsen/logrus" -) - -func init() { - log.SetLevel(log.DebugLevel) -} - -func main() { - // get input from user - reader := bufio.NewReader(os.Stdin) - fmt.Print("Image Description: ") - - input, err := reader.ReadString('\n') - if err != nil { - panic("reader.ReadString failed: " + err.Error()) - } - - // setup test environment - s := docker.NewTestSetup() - defer s.TearDown() - - // send a http request to the server and check the response - client := &http.Client{} - req, err := http.NewRequest("GET", "http://localhost:8080", bytes.NewBuffer( - []byte(input), - )) - - if err != nil { - panic("http.NewRequest failed: " + err.Error()) - } - - // add headers - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - req.Header.Add("Spear-Func-Id", "3") - req.Header.Add("Spear-Func-Type", "1") - - // send the request - resp, err := client.Do(req) - if err != nil { - panic("client.Do failed: " + err.Error()) - } - - // print the response - buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) - - respData := buf.Bytes() - log.Debugf("Received response length: %d", len(respData)) - - var respStruct transform.ImageGenerationResponse - err = respStruct.Unmarshal(respData) - if err != nil { - panic("respStruct.Unmarshal failed: " + err.Error()) - } - - if len(respStruct.Data) != 1 { - panic("expected 1 image, got " + string(len(respStruct.Data))) - } - - // resp is a image in base64 format - // decode the image - img := make([]byte, base64.StdEncoding.DecodedLen(len(respStruct.Data[0].B64Json))) - _, err = base64.StdEncoding.Decode(img, []byte(respStruct.Data[0].B64Json)) - if err != nil { - panic("base64.StdEncoding.Decode failed: " + err.Error()) - } - - // write the image to a temp file using os.CreateTemp - file, err := os.CreateTemp("", "image-*.png") - if err != nil { - panic("os.CreateTemp failed: " + err.Error()) - } - // write the image to the file - _, err = file.Write(img) - if err != nil { - panic("file.Write failed: " + err.Error()) - } - // close the file - file.Close() - - // open the file using the default application - err = openImage(file.Name()) - if err != nil { - panic("openImage failed: " + err.Error()) - } - - // close the response body - resp.Body.Close() -} - -func openImage(filePath string) error { - var cmd *exec.Cmd - - // Determine the command based on the OS - switch runtime.GOOS { - case "linux": - cmd = exec.Command("xdg-open", filePath) - case "darwin": - cmd = exec.Command("open", filePath) - case "windows": - cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", filePath) - default: - return fmt.Errorf("unsupported platform") - } - - return cmd.Start() -} diff --git a/workload/docker/go/gen_image/src/start.go b/workload/docker/go/gen_image/src/start.go index fbfaae4..f524af3 100644 --- a/workload/docker/go/gen_image/src/start.go +++ b/workload/docker/go/gen_image/src/start.go @@ -2,20 +2,20 @@ package main import ( "encoding/binary" - "fmt" "io" "net" "os" "strconv" "time" - 
"github.com/lfedgeai/spear/pkg/rpc" - "github.com/lfedgeai/spear/pkg/rpc/payload/transform" + flatbuffers "github.com/google/flatbuffers/go" + spearnet "github.com/lfedgeai/spear/pkg/net" + "github.com/lfedgeai/spear/pkg/spear/proto/custom" log "github.com/sirupsen/logrus" ) -var hdl *rpc.GuestRPCManager +var hdl *spearnet.GuestRPCManager var hostaddr string var secret string @@ -56,29 +56,37 @@ func init() { } func main() { - hdl = rpc.NewGuestRPCManager(nil, nil) + hdl = spearnet.NewGuestRPCManager() hdl.SetInput(input) hdl.SetOutput(output) done := make(chan bool) - hdl.RegisterIncomingHandler("handle", func(args interface{}) (interface{}, error) { + hdl.RegisterIncomingCustomRequestHandler("handle", func(args *custom.CustomRequest) (*custom.CustomResponse, error) { defer func() { done <- true }() log.Debugf("Incoming request: %v", args) + str := string(args.ParamsStr()) // make sure args is a string - if str, ok := args.(string); ok { - resp, err := generateImage(str) - if err != nil { - log.Errorf("failed to generate image: %v", err) - return nil, err - } - log.Debugf("Generated image: %v", resp) - return resp, nil - } else { - return nil, fmt.Errorf("expected string, got %T", args) + resp, err := generateImage(str) + if err != nil { + log.Errorf("failed to generate image: %v", err) + return nil, err } + log.Debugf("Generated image: %v", resp) + + builder := flatbuffers.NewBuilder(0) + respOff := builder.CreateByteVector(resp) + + custom.CustomResponseStart(builder) + custom.CustomResponseAddData(builder, respOff) + builder.Finish(custom.CustomResponseEnd(builder)) + + respBytes := builder.FinishedBytes() + + customResp := custom.GetRootAsCustomResponse(respBytes, 0) + return customResp, nil }) go hdl.Run() @@ -87,10 +95,13 @@ func main() { time.Sleep(5 * time.Second) } -func generateImage(str string) (*transform.ImageGenerationResponse, error) { - res, err := rpc.TextToImage(hdl, "dall-e-3", str, "b64_json") - if err != nil { - return nil, fmt.Errorf("failed to generate image: %v", err) - } - return res, nil +func generateImage(str string) ([]byte, error) { + log.Infof("this test is temporarily disabled") + + // res, err := rpc.TextToImage(hdl, "dall-e-3", str, "b64_json") + // if err != nil { + // return nil, fmt.Errorf("failed to generate image: %v", err) + // } + // return res, nil + return nil, nil } diff --git a/workload/docker/go/voice_chat/Dockerfile b/workload/docker/go/voice_chat/Dockerfile deleted file mode 100644 index 681ddd6..0000000 --- a/workload/docker/go/voice_chat/Dockerfile +++ /dev/null @@ -1,8 +0,0 @@ -FROM scratch - -WORKDIR / - -# COPY ./scripts/start.sh /start -COPY ./bin/start /start - -CMD ["sleep", "infinity"] diff --git a/workload/docker/go/voice_chat/Makefile b/workload/docker/go/voice_chat/Makefile deleted file mode 100644 index dbe6095..0000000 --- a/workload/docker/go/voice_chat/Makefile +++ /dev/null @@ -1,23 +0,0 @@ -.PHONY: all build clean - -CURRENT_DIR := $(shell pwd) -OUTPUT_DIR := $(shell pwd)/bin -PROJ_NAME := $(shell basename $(CURRENT_DIR)) - -all: build - docker compose build - -start: - GOOS=linux go build -o $(OUTPUT_DIR)/start \ - $(CURRENT_DIR)/src/start.go - -build: start demo - -demo: - go build -o $(OUTPUT_DIR)/demo \ - $(CURRENT_DIR)/src/demo.go - -clean: - rm -rf $(OUTPUT_DIR) - -.PHONY: all build clean demo diff --git a/workload/docker/go/voice_chat/compose.yaml b/workload/docker/go/voice_chat/compose.yaml deleted file mode 100644 index 6121369..0000000 --- a/workload/docker/go/voice_chat/compose.yaml +++ /dev/null @@ -1,6 +0,0 @@ 
-services: - dummy: - image: voice_chat - build: - context: . - dockerfile: Dockerfile \ No newline at end of file diff --git a/workload/docker/go/voice_chat/src/demo.go b/workload/docker/go/voice_chat/src/demo.go deleted file mode 100644 index 7ea1f46..0000000 --- a/workload/docker/go/voice_chat/src/demo.go +++ /dev/null @@ -1,137 +0,0 @@ -package main - -import ( - "bufio" - "bytes" - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "os" - "time" - - "github.com/faiface/beep" - "github.com/faiface/beep/mp3" - "github.com/faiface/beep/speaker" - "github.com/lfedgeai/spear/pkg/tools/docker" - log "github.com/sirupsen/logrus" -) - -func init() { - log.SetLevel(log.DebugLevel) -} - -func main() { - // get input from user - reader := bufio.NewReader(os.Stdin) - fmt.Print("Message to LLM: ") - - input, err := reader.ReadString('\n') - if err != nil { - panic("reader.ReadString failed: " + err.Error()) - } - - // setup test environment - s := docker.NewTestSetup() - defer s.TearDown() - - // send a http request to the server and check the response - client := &http.Client{} - req, err := http.NewRequest("GET", "http://localhost:8080", bytes.NewBuffer( - []byte(input), - )) - - if err != nil { - panic("http.NewRequest failed: " + err.Error()) - } - - // add headers - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - req.Header.Add("Spear-Func-Id", "2") - req.Header.Add("Spear-Func-Type", "1") - - // send the request - resp, err := client.Do(req) - if err != nil { - panic("client.Do failed: " + err.Error()) - } - - // print the response - buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) - - respData := buf.Bytes() - - log.Debugf("Received response length: %d", len(respData)) - - // umarshal the response - var data map[string]interface{} - err = json.Unmarshal([]byte(respData), &data) - if err != nil { - log.Errorf("respData: %s", respData) - panic("json.Unmarshal failed: " + err.Error()) - } - - // get the "audio" key from the response - encodedData, ok := data["audio"] - if !ok { - panic("audio key not found in response") - } - // convert from base64 to []byte - rawData, err := base64.StdEncoding.DecodeString(encodedData.(string)) - if err != nil { - panic("base64.StdEncoding.DecodeString failed: " + err.Error()) - } - - // write to a temp file - f, err := os.CreateTemp("", "audio*.mp3") - if err != nil { - panic("os.CreateTemp failed: " + err.Error()) - } - log.Debugf("Data Length: %d", len(rawData)) - // wrtie the audio data to the file - _, err = f.Write(rawData) - if err != nil { - panic("f.Write failed: " + err.Error()) - } - f.Close() - log.Debugf("Created temp file: %s", f.Name()) - - playMP3(f.Name()) - - // close the response body - resp.Body.Close() -} - -func playMP3(filePath string) error { - // Open the MP3 file - f, err := os.Open(filePath) - if err != nil { - return fmt.Errorf("could not open MP3 file: %w", err) - } - defer f.Close() - - // Decode the MP3 file - stream, format, err := mp3.Decode(f) - if err != nil { - return fmt.Errorf("could not decode MP3 file: %w", err) - } - defer stream.Close() - - // Initialize the speaker with the sample rate - err = speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)) - if err != nil { - return fmt.Errorf("could not initialize speaker: %w", err) - } - - // Play the audio stream - done := make(chan bool) - speaker.Play(beep.Seq(stream, beep.Callback(func() { - done <- true - }))) - - // Wait until the audio finishes playing - <-done - return nil -} diff --git 
a/workload/docker/go/voice_chat/src/start.go b/workload/docker/go/voice_chat/src/start.go deleted file mode 100644 index 06ac14a..0000000 --- a/workload/docker/go/voice_chat/src/start.go +++ /dev/null @@ -1,118 +0,0 @@ -package main - -import ( - "encoding/binary" - "fmt" - "io" - "net" - "os" - "strconv" - "time" - - "github.com/lfedgeai/spear/pkg/rpc" - "github.com/lfedgeai/spear/pkg/rpc/payload" - "github.com/lfedgeai/spear/pkg/rpc/payload/transform" - - log "github.com/sirupsen/logrus" -) - -var hdl *rpc.GuestRPCManager -var hostaddr string -var secret string - -var input io.Reader -var output io.Writer - -// parse arguments -func init() { - // get hostaddr and secret from environment variables - hostaddr = os.Getenv("SERVICE_ADDR") - secret = os.Getenv("SECRET") - - log.Debugf("Connecting to host at %s", hostaddr) - // create tcp connection to host - conn, err := net.Dial("tcp", hostaddr) - if err != nil { - log.Fatalf("failed to connect to host: %v", err) - } - - // sending the secret - // convert secret string to int64 - secretInt, err := strconv.ParseInt(secret, 10, 64) - if err != nil { - log.Fatalf("failed to convert secret to int64: %v", err) - } - // convert int64 to little endian byte array - secretBytes := make([]byte, 8) - binary.LittleEndian.PutUint64(secretBytes, uint64(secretInt)) - // write secret to connection - _, err = conn.Write(secretBytes) - if err != nil { - log.Fatalf("failed to write secret to connection: %v", err) - } - - // create input and output files from connection - input = conn - output = conn -} - -func main() { - hdl = rpc.NewGuestRPCManager(nil, nil) - hdl.SetInput(input) - hdl.SetOutput(output) - - done := make(chan bool) - - hdl.RegisterIncomingHandler("handle", func(args interface{}) (interface{}, error) { - defer func() { - done <- true - }() - log.Debugf("Incoming request: %v", args) - // make sure args is a string - if str, ok := args.(string); ok { - resp, err := getTextResponse(str) - if err != nil { - return nil, err - } - t2sResp, err := text2Speech(resp) - if err != nil { - return nil, err - } - log.Debugf("Encoded response length in task handle: %d", len(t2sResp.EncodedAudio)) - return t2sResp, nil - } else { - return nil, fmt.Errorf("expected string, got %T", args) - } - }) - go hdl.Run() - - <-done - log.Debug("Exiting") - time.Sleep(5 * time.Second) -} - -func getTextResponse(str string) (string, error) { - res, err := rpc.ChatCompletion(hdl, "gpt-4o", []payload.ChatMessageV2{ - { - Metadata: map[string]interface{}{ - "role": "user", - }, - Content: str, - }, - }, "") - if err != nil { - return "", err - } - if len(res) == 0 { - return "", fmt.Errorf("no response returned") - } - return res[len(res)-1].Content, nil -} - -func text2Speech(str string) (*transform.TextToSpeechResponse, error) { - res, err := rpc.TextToSpeech(hdl, "tts-1", "alloy", str, "mp3") - if err != nil { - return nil, err - } - return res, nil -} diff --git a/workload/docker/python/pychat/Dockerfile b/workload/docker/python/pychat/Dockerfile index ddd8ee0..b340c0e 100644 --- a/workload/docker/python/pychat/Dockerfile +++ b/workload/docker/python/pychat/Dockerfile @@ -10,7 +10,7 @@ WORKDIR / COPY ./workload/docker/python/pychat/src/start.py /start COPY ./sdk/python/dist/spear*.whl /tmp/ -RUN pip install /tmp/spear*.whl +RUN pip install --no-cache-dir /tmp/spear*.whl RUN rm /tmp/spear*.whl CMD ["sleep", "infinity"] diff --git a/workload/docker/python/pychat/Makefile b/workload/docker/python/pychat/Makefile index 492bdd3..fd4ddaa 100644 --- 
a/workload/docker/python/pychat/Makefile +++ b/workload/docker/python/pychat/Makefile @@ -3,7 +3,7 @@ OUTPUT_DIR := $(shell pwd)/bin PROJ_NAME := $(shell basename $(CURRENT_DIR)) all: build - docker compose build + docker-compose build --no-cache build: demo diff --git a/workload/docker/python/pychat/compose.yaml b/workload/docker/python/pychat/compose.yaml index 9a7cc1b..992c7b4 100644 --- a/workload/docker/python/pychat/compose.yaml +++ b/workload/docker/python/pychat/compose.yaml @@ -1,6 +1,6 @@ services: - dummy: - image: pychat + pychat: + image: pychat:latest build: context: ../../../../ dockerfile: ./workload/docker/python/pychat/Dockerfile \ No newline at end of file diff --git a/workload/docker/python/pyconversation-local/Dockerfile b/workload/docker/python/pyconversation-local/Dockerfile index 7e98ebd..14f51d0 100644 --- a/workload/docker/python/pyconversation-local/Dockerfile +++ b/workload/docker/python/pyconversation-local/Dockerfile @@ -4,8 +4,8 @@ FROM python:3.9.20 RUN apt-get update -y RUN apt-get install -y netcat-openbsd net-tools -RUN pip install --upgrade pip -RUN pip install py-spy +RUN pip install --no-cache-dir --upgrade pip +RUN pip install --no-cache-dir py-spy WORKDIR / @@ -13,7 +13,7 @@ WORKDIR / COPY ./workload/docker/python/pyconversation-local/src/start.py /start COPY ./sdk/python/dist/spear*.whl /tmp/ -RUN pip install /tmp/spear*.whl +RUN pip install --no-cache-dir /tmp/spear*.whl RUN rm /tmp/spear*.whl CMD ["sleep", "infinity"] diff --git a/workload/docker/python/pyconversation-local/Makefile b/workload/docker/python/pyconversation-local/Makefile index b9d6584..dacf270 100644 --- a/workload/docker/python/pyconversation-local/Makefile +++ b/workload/docker/python/pyconversation-local/Makefile @@ -5,7 +5,7 @@ PROJ_NAME := $(shell basename $(CURRENT_DIR)) all: build build: - docker compose build + docker-compose build --no-cache clean: rm -rf $(OUTPUT_DIR) diff --git a/workload/docker/python/pyconversation-local/compose.yaml b/workload/docker/python/pyconversation-local/compose.yaml index b08aabc..832b711 100644 --- a/workload/docker/python/pyconversation-local/compose.yaml +++ b/workload/docker/python/pyconversation-local/compose.yaml @@ -1,6 +1,6 @@ services: - dummy: - image: pyconversation + pyconv: + image: pyconversation:latest build: context: ../../../../ dockerfile: ./workload/docker/python/pyconversation-local/Dockerfile \ No newline at end of file diff --git a/workload/docker/python/pydummy/Dockerfile b/workload/docker/python/pydummy/Dockerfile index 37ea26c..2623478 100644 --- a/workload/docker/python/pydummy/Dockerfile +++ b/workload/docker/python/pydummy/Dockerfile @@ -4,8 +4,8 @@ FROM python:3.9.20 RUN apt-get update -y RUN apt-get install -y netcat-openbsd net-tools -RUN pip install --upgrade pip -RUN pip install py-spy +RUN pip install --no-cache-dir --upgrade pip +RUN pip install --no-cache-dir py-spy WORKDIR / @@ -13,7 +13,7 @@ WORKDIR / COPY ./workload/docker/python/pydummy/src/start.py /start COPY ./sdk/python/dist/spear*.whl /tmp/ -RUN pip install /tmp/spear*.whl +RUN pip install --no-cache-dir /tmp/spear*.whl RUN rm /tmp/spear*.whl CMD ["sleep", "infinity"] diff --git a/workload/docker/python/pydummy/Makefile b/workload/docker/python/pydummy/Makefile index b9d6584..dacf270 100644 --- a/workload/docker/python/pydummy/Makefile +++ b/workload/docker/python/pydummy/Makefile @@ -5,7 +5,7 @@ PROJ_NAME := $(shell basename $(CURRENT_DIR)) all: build build: - docker compose build + docker-compose build --no-cache clean: rm -rf $(OUTPUT_DIR) diff 
--git a/workload/docker/python/pydummy/compose.yaml b/workload/docker/python/pydummy/compose.yaml index b2dab00..4db5b0e 100644 --- a/workload/docker/python/pydummy/compose.yaml +++ b/workload/docker/python/pydummy/compose.yaml @@ -1,6 +1,6 @@ services: dummy: - image: pydummy + image: pydummy:latest build: context: ../../../../ dockerfile: ./workload/docker/python/pydummy/Dockerfile \ No newline at end of file diff --git a/workload/docker/python/pydummy/src/start.py b/workload/docker/python/pydummy/src/start.py index a525f39..7cd616f 100755 --- a/workload/docker/python/pydummy/src/start.py +++ b/workload/docker/python/pydummy/src/start.py @@ -1,14 +1,9 @@ #!/usr/bin/env python3 -import argparse -import base64 import logging -import os import sys import spear.client as client -import spear.hostcalls.tools as tools -import spear.hostcalls.transform as tf -import spear.utils.io as io +import spear.transform.chat as chat logging.basicConfig( level=logging.DEBUG, # Set the desired logging level @@ -22,13 +17,15 @@ agent = client.HostAgent() + def handle(params): """ handle the request """ logger.debug("Handling request: %s", params) - test("text-embedding-ada-002") - test("bge-large-en-v1.5") + test("gpt-4o") + #test("text-embedding-ada-002") + #test("bge-large-en-v1.5") def test(model): @@ -36,26 +33,9 @@ def test(model): test the model """ logger.info("Testing model: %s", model) - resp = agent.exec_request( - "transform", - tf.TransformRequest( - input_types=[tf.TransformType.TEXT], - output_types=[tf.TransformType.VECTOR], - operations=[tf.TransformOperation.EMBEDDINGS], - params={ - "model": model, - "input": "hi", - }, - ), - ) - if isinstance(resp, client.JsonRpcOkResp): - resp = tf.TransformResponse.schema().load(resp.result) - # base64 decode the response string - data = resp.results[0].data - logger.info("Response Len: %s", len(data)) - elif isinstance(resp, client.JsonRpcErrorResp): - raise Exception(resp) + resp = chat.chat(agent, "hi", model=model) + logger.info(resp) agent.stop() return "done" diff --git a/workload/docker/python/pytest-functionality/Dockerfile b/workload/docker/python/pytest-functionality/Dockerfile new file mode 100644 index 0000000..40ea57f --- /dev/null +++ b/workload/docker/python/pytest-functionality/Dockerfile @@ -0,0 +1,19 @@ +FROM python:3.9.20 + +# update the package lists +RUN apt-get update -y +RUN apt-get install -y netcat-openbsd net-tools + +RUN pip install --no-cache-dir --upgrade pip +RUN pip install --no-cache-dir py-spy + +WORKDIR / + +# COPY ./scripts/start.sh /start +COPY ./workload/docker/python/pytest-functionality/src/start.py /start + +COPY ./sdk/python/dist/spear*.whl /tmp/ +RUN pip install --no-cache-dir /tmp/spear*.whl +RUN rm /tmp/spear*.whl + +CMD ["sleep", "infinity"] diff --git a/workload/docker/python/pytest-functionality/Makefile b/workload/docker/python/pytest-functionality/Makefile new file mode 100644 index 0000000..dacf270 --- /dev/null +++ b/workload/docker/python/pytest-functionality/Makefile @@ -0,0 +1,13 @@ +CURRENT_DIR := $(shell pwd) +OUTPUT_DIR := $(shell pwd)/bin +PROJ_NAME := $(shell basename $(CURRENT_DIR)) + +all: build + +build: + docker-compose build --no-cache + +clean: + rm -rf $(OUTPUT_DIR) + +.PHONY: all build diff --git a/workload/docker/python/pytest-functionality/compose.yaml b/workload/docker/python/pytest-functionality/compose.yaml new file mode 100644 index 0000000..6aa88c8 --- /dev/null +++ b/workload/docker/python/pytest-functionality/compose.yaml @@ -0,0 +1,6 @@ +services: + functionality: + image: 
pytest-functionality:latest + build: + context: ../../../../ + dockerfile: ./workload/docker/python/pytest-functionality/Dockerfile \ No newline at end of file diff --git a/workload/docker/python/pytest-functionality/src/start.py b/workload/docker/python/pytest-functionality/src/start.py new file mode 100755 index 0000000..2849f17 --- /dev/null +++ b/workload/docker/python/pytest-functionality/src/start.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +import logging +import sys +import time + +import spear.client as client +import spear.transform.chat as chat +import spear.utils.io as io + +from spear.proto.tool import BuiltinToolID +from spear.utils.tool import register_internal_tool + +logging.basicConfig( + level=logging.DEBUG, # Set the desired logging level + # Customize the log format + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(stream=sys.stderr)], # Log to stderr +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +agent = client.HostAgent() + + +def handle(params): + """ + handle the request + """ + logger.info("Handling request: %s", params) + + logger.info("testing chat") + test_chat("gpt-4o") + + logger.info("testing speak") + test_speak("tts-1") + + logger.info("testing record") + test_record("whisper-1") + + logger.info("testing input") + test_input() + + logger.info("testing tool") + test_tool() + # test("text-embedding-ada-002") + # test("bge-large-en-v1.5") + + time.sleep(10) + # agent.stop() + + +def test_chat(model): + """ + test the model + """ + logger.info("Testing model: %s", model) + + resp = chat.chat(agent, "hi", model=model) + logger.info(resp) + resp = chat.chat(agent, "what is the time now?", + model=model, builtin_tools=[ + BuiltinToolID.BuiltinToolID.Datetime, + ]) + logger.info(resp) + + +def test_speak(model): + """ + test the model + """ + logger.info("Testing model: %s", model) + + resp = io.speak(agent, "test test test") + assert resp is not None + + +def test_record(model): + """ + test the model + """ + logger.info("Testing model: %s", model) + + resp = io.record(agent, "recording test") + assert resp is not None + + +def test_input(): + """ + test the model + """ + logger.info("Testing input") + + resp = io.input(agent, "input", True) + logger.info(resp) + + +def test_tool_cb(param1, param2): + """ + spear tool callback test function + + @param param1: test parameter 1 + @param param2: test parameter 2 + """ + logger.info("Testing tool callback %s %s", param1, param2) + return "test" + + +def test_tool(): + """ + test the model + """ + logger.info("Testing tool") + tid = register_internal_tool(agent, test_tool_cb) + logger.info("Registered tool: %d", tid) + + +if __name__ == "__main__": + agent.register_handler("handle", handle) + agent.run() diff --git a/workload/docker/python/pytools/Dockerfile b/workload/docker/python/pytools/Dockerfile index acdbad1..81a66ea 100644 --- a/workload/docker/python/pytools/Dockerfile +++ b/workload/docker/python/pytools/Dockerfile @@ -10,7 +10,7 @@ WORKDIR / COPY ./workload/docker/python/pytools/src/start.py /start COPY ./sdk/python/dist/spear*.whl /tmp/ -RUN pip install /tmp/spear*.whl +RUN pip install --no-cache-dir /tmp/spear*.whl RUN rm /tmp/spear*.whl CMD ["sleep", "infinity"] diff --git a/workload/docker/python/pytools/Makefile b/workload/docker/python/pytools/Makefile index 492bdd3..fd4ddaa 100644 --- a/workload/docker/python/pytools/Makefile +++ b/workload/docker/python/pytools/Makefile @@ -3,7 +3,7 @@ OUTPUT_DIR := $(shell pwd)/bin 
PROJ_NAME := $(shell basename $(CURRENT_DIR)) all: build - docker compose build + docker-compose build --no-cache build: demo diff --git a/workload/docker/python/pytools/compose.yaml b/workload/docker/python/pytools/compose.yaml index 30c8ab9..412b98f 100644 --- a/workload/docker/python/pytools/compose.yaml +++ b/workload/docker/python/pytools/compose.yaml @@ -1,6 +1,6 @@ services: - dummy: - image: pytools + pytools: + image: pytools:latest build: context: ../../../../ dockerfile: ./workload/docker/python/pytools/Dockerfile \ No newline at end of file diff --git a/workload/process/dummy/Makefile b/workload/process/dummy/Makefile deleted file mode 100644 index 7a28242..0000000 --- a/workload/process/dummy/Makefile +++ /dev/null @@ -1,11 +0,0 @@ -.PHONY: all clean - -PROJECT_ROOT := $(shell pwd)/../../.. -OUTPUT_DIR := $(PROJECT_ROOT)/bin - -all: - go build -o $(OUTPUT_DIR)/dummy_task \ - $(PROJECT_ROOT)/workload/process/dummy/main.go - -clean: - rm -rf $(OUTPUT_DIR) diff --git a/workload/process/dummy/main.go b/workload/process/dummy/main.go deleted file mode 100644 index 571c590..0000000 --- a/workload/process/dummy/main.go +++ /dev/null @@ -1,101 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "os" - "time" - - // flags support - "flag" - - "github.com/lfedgeai/spear/pkg/rpc" - "github.com/lfedgeai/spear/pkg/rpc/payload" - - // logrus - log "github.com/sirupsen/logrus" -) - -// input and output flags using -i/-o -var input string -var output string - -func init() { - flag.StringVar(&input, "i", "", "input file") - flag.StringVar(&output, "o", "", "output file") - flag.Parse() -} - -func main() { - // open input pipe and output pipe - inPipe, err := os.OpenFile(input, os.O_RDONLY, os.ModeNamedPipe) - if err != nil { - panic(err) - } - outPipe, err := os.OpenFile(output, os.O_WRONLY, os.ModeNamedPipe) - if err != nil { - panic(err) - } - - hdl := rpc.NewGuestRPCManager( - func(req *rpc.JsonRPCRequest) (*rpc.JsonRPCResponse, error) { - log.Infof("Request: %s", *req.Method) - return rpc.NewJsonRPCResponse(*req.ID, nil), nil - }, - func(resp *rpc.JsonRPCResponse) error { - log.Infof("Response: %s", resp.Result) - - // convert resp.Result to buffer - data, err := json.Marshal(resp.Result) - if err != nil { - log.Errorf("Error marshalling response: %v", err) - panic(err) - } - - if len(data) > 2048 { - log.Infof("Response: %s", data[:2048]) - } else { - log.Infof("Response: %s", data) - } - - return nil - }, - ) - hdl.SetInput(inPipe) - hdl.SetOutput(outPipe) - go hdl.Run() - - // // send an embeddings request - // embeddingsReq := transform.EmbeddingsRequest{ - // Model: "text-embedding-ada-002", - // Input: "The food was delicious and the waiter...", - // } - - // req2 := rpc.NewJsonRPCRequest(transform.HostCallEmbeddings, embeddingsReq) - // err = req2.Send(outPipe) - // if err != nil { - // panic(err) - // } - - randName := fmt.Sprintf("vdb-%d", time.Now().UnixNano()) - - // vector store ops - req3 := rpc.NewJsonRPCRequest(payload.HostCallVectorStoreCreate, payload.VectorStoreCreateRequest{ - Name: randName, - }) - err = req3.Send(outPipe) - if err != nil { - panic(err) - } - - // delete vector store - req4 := rpc.NewJsonRPCRequest(payload.HostCallVectorStoreDelete, payload.VectorStoreDeleteRequest{ - VID: 0, - }) - err = req4.Send(outPipe) - if err != nil { - panic(err) - } - - time.Sleep(5 * time.Second) -}
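-- 
For reference, the smallest guest workload against the new flatbuffers-based Python SDK follows the same shape as pytest-functionality/src/start.py and pydummy/src/start.py above: create a HostAgent, register a "handle" entry point, and call the transform helpers. A minimal sketch, assuming the spear wheel built from sdk/python is installed (the model name is the one the test workloads use):

#!/usr/bin/env python3
# Minimal guest workload sketch. Assumes the spear wheel built from
# sdk/python is installed; the entry-point and helper names below are the
# ones used by the test workloads in this patch.
import logging
import sys

import spear.client as client
import spear.transform.chat as chat

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(stream=sys.stderr)],
)
logger = logging.getLogger(__name__)

agent = client.HostAgent()


def handle(params):
    """Handle one request from the spearlet by running a chat transform."""
    logger.info("Handling request: %s", params)
    resp = chat.chat(agent, "hi", model="gpt-4o")  # flatbuffers-backed hostcall
    logger.info(resp)
    agent.stop()
    return "done"


if __name__ == "__main__":
    agent.register_handler("handle", handle)
    agent.run()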