Skip to content

Commit

Permalink
support toolcall from agent directly
Browse files Browse the repository at this point in the history
  • Loading branch information
wilsonwang371 committed Dec 19, 2024
1 parent b50be2f commit fc8cb53
Show file tree
Hide file tree
Showing 35 changed files with 555 additions and 363 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.23.0
require (
github.com/docker/docker v27.3.1+incompatible
github.com/docker/go-connections v0.5.0
github.com/google/uuid v1.6.0
github.com/kbinani/screenshot v0.0.0-20240820160931-a8a2c5d0e191
github.com/qdrant/go-client v1.12.0
github.com/schollz/progressbar/v3 v3.17.1
Expand Down
4 changes: 2 additions & 2 deletions pkg/rpc/payload/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
type ChatCompletionRequest struct {
Messages []ChatMessageV2 `json:"messages"`
Model string `json:"model"`
ToolsetId string `json:"toolset_id"`
ToolsetId int `json:"toolset_id"` // valid if toolset_id is positive
}

// marshal operations
Expand Down Expand Up @@ -67,7 +67,7 @@ func (r *ChatCompletionResponseV2) Unmarshal(data []byte) error {
type ChatCompletionRequestV2 struct {
Messages []ChatMessageV2 `json:"messages"`
Model string `json:"model"`
ToolsetId string `json:"toolset_id"`
ToolsetId int `json:"toolset_id"`
}

func (r *ChatCompletionRequestV2) Marshal() ([]byte, error) {
Expand Down
1 change: 1 addition & 0 deletions pkg/rpc/payload/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const (
HostCallToolNew = "tool.new"
HostCallToolsetNew = "toolset.new"
HostCallToolsetInstallBuiltins = "toolset.install.builtins"
HostCallToolCall = "tool.call"

HostCallInput = "input"
HostCallSpeak = "speak"
Expand Down
21 changes: 14 additions & 7 deletions pkg/rpc/payload/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,37 @@ type NewToolParam struct {
type NewToolRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ToolsetID int `json:"toolset_id"`
Params []NewToolParam `json:"params"`
Cb string `json:"cb"`
}

type NewToolResponse struct {
Tid string `json:"tool_id"`
ToolsetID int `json:"tool_id"`
}

type NewToolsetRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ToolIds []string `json:"tool_ids"`
Name string `json:"name"`
Description string `json:"description"`
WorkloadName string `json:"workload_name"`
}

type NewToolsetResponse struct {
Tsid string `json:"toolset_id"`
ToolsetID int `json:"toolset_id"`
}

type ToolsetInstallBuiltinsRequest struct {
Tsid string `json:"toolset_id"`
ToolsetID int `json:"toolset_id"`
}

type ToolsetInstallBuiltinsResponse struct {
Tsid string `json:"toolset_id"`
ToolsetID int `json:"toolset_id"`
}

type ToolCallRequest struct {
ToolsetID int `json:"toolset_id"`
ToolID int `json:"tool_id"`
Params map[string]interface{} `json:"params"`
}

func (r *NewToolRequest) Marshal() ([]byte, error) {
Expand Down
6 changes: 5 additions & 1 deletion pkg/rpc/transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ import (
"github.com/lfedgeai/spear/pkg/utils"
)

func ChatCompletion(rpcMgr *GuestRPCManager, model string, msgs []payload.ChatMessageV2, tsId string) ([]payload.ChatMessageV2, error) {
func ChatCompletion(rpcMgr *GuestRPCManager, model string, msgs []payload.ChatMessageV2) ([]payload.ChatMessageV2, error) {
return ChatCompletionWithToolset(rpcMgr, model, msgs, -1)
}

func ChatCompletionWithToolset(rpcMgr *GuestRPCManager, model string, msgs []payload.ChatMessageV2, tsId int) ([]payload.ChatMessageV2, error) {
req := &payload.ChatCompletionRequestV2{
Model: model,
Messages: msgs,
Expand Down
23 changes: 18 additions & 5 deletions sdk/python/spear/hostcalls/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,47 @@ class NewToolParams:
"""
The parameters for the newtool hostcall.
"""

name: str
type: str
description: str
required: bool


@dataclass_json
@dataclass
class NewToolRequest:
"""
The request object for the newtool hostcall.
"""

name: str
description: str
toolset_id: int
params: list[NewToolParams]
cb: str


@dataclass_json
@dataclass
class NewToolResponse:
"""
The response object for the newtool hostcall.
"""
tool_id: str

tool_id: int


@dataclass_json
@dataclass
class NewToolsetRequest:
"""
The request object for the newtoolset hostcall.
"""

name: str
description: str
tool_ids: list[str]
workload_name: str = ""


@dataclass_json
Expand All @@ -53,20 +61,25 @@ class NewToolsetResponse:
"""
The response object for the newtoolset hostcall.
"""
toolset_id: str

toolset_id: int


@dataclass_json
@dataclass
class ToolsetInstallBuiltinsRequest:
"""
The request object for the toolset.install_builtins hostcall.
"""
toolset_id: str

toolset_id: int


@dataclass_json
@dataclass
class ToolsetInstallBuiltinsResponse:
"""
The response object for the toolset.install_builtins hostcall.
"""
toolset_id: str

toolset_id: int
44 changes: 44 additions & 0 deletions sdk/python/spear/utils/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
import logging

import spear.client as client
import spear.hostcalls.transform as tf

logger = logging.getLogger(__name__)


def chat_completion(
agent: client.HostAgent,
prompt: str,
toolset_id: int = -1,
role: str = "user",
model: str = "gpt-4o",
) -> list[tf.ChatChoice]:
"""
get user input
"""
resp = agent.exec_request(
"transform",
tf.TransformRequest(
input_types=[tf.TransformType.TEXT],
output_types=[tf.TransformType.TEXT],
operations=[tf.TransformOperation.LLM, tf.TransformOperation.TOOLS],
params={
"model": model,
"messages": [
tf.ChatMessageV2(
metadata=tf.ChatMessageV2Metadata(role=role),
content=prompt,
)
],
"toolset_id": toolset_id,
},
),
)

if isinstance(resp, client.JsonRpcOkResp):
resp = tf.TransformResponse.schema().load(resp.result)
resp = tf.ChatResponseV2.schema().load(resp.results[0].data)
return resp.messages
else:
raise ValueError(f"Error: {resp.message}")
32 changes: 32 additions & 0 deletions sdk/python/spear/utils/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
import logging

import spear.client as client
import spear.hostcalls.tools as tools

logger = logging.getLogger(__name__)


def new_toolset(
agent: client.HostAgent, name: str, description: str, workload_name: str = None
) -> str:
"""
create a new toolset
"""
req = {
"name": name,
"description": description,
}
if workload_name:
req["workload_name"] = workload_name
resp = agent.exec_request(
"toolset.new",
req,
)

if isinstance(resp, client.JsonRpcOkResp):
logger.debug("Toolset created with id: %s", resp.result)
resp = tools.NewToolsetResponse.schema().load(resp.result)
return resp.toolset_id
else:
raise ValueError(f"Error creating toolset: {resp.message}")
Loading

0 comments on commit fc8cb53

Please sign in to comment.