Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tj202/pr minor fixes 1 #29

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
24 changes: 22 additions & 2 deletions libs/agentc/agentc/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from agentc_core.provider import PythonTarget
from agentc_core.provider import ToolProvider
from agentc_core.version import VersionDescriptor
from typing import Literal
from typing import Union

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -361,7 +363,25 @@ def version(self) -> VersionDescriptor:
version_tuples.append(self._remote_prompt_catalog.version)
return sorted(version_tuples, key=lambda x: x.timestamp, reverse=True)[0]

def get_tools_for(
def get_item(
self,
query: str = None,
name: str = None,
annotations: str = None,
snapshot: str = LATEST_SNAPSHOT_VERSION,
limit: typing.Union[int | None] = 1,
item_type: Literal["tool", "prompt", "agent"] = None,
) -> Union[list[typing.Any] | Prompt | None]:
if item_type == "tool":
return self._get_tools_for(query, name, annotations, snapshot, limit)
elif item_type == "prompt":
return self._get_prompt_for(query, name, annotations, snapshot)
elif item_type == "agent":
pass
TJ202 marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(f"Unknown item type: {item_type}, expected 'tool', 'prompt', or 'agent'.")

def _get_tools_for(
self,
query: str = None,
name: str = None,
Expand All @@ -387,7 +407,7 @@ def get_tools_for(
else:
return [self._tool_provider.get(name=name, annotations=annotations, snapshot=snapshot)]

def get_prompt_for(
def _get_prompt_for(
self,
query: str = None,
name: str = None,
Expand Down
10 changes: 5 additions & 5 deletions libs/agentc/tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_local_tool_provider(tmp_path):
)
os.chdir(td)
provider = Provider()
tools = provider.get_tools_for("searching travel blogs")
tools = provider.get_item(query="searching travel blogs", item_type="tool")
assert len(tools) == 1
assert tools[0].__name__ == "get_travel_blog_snippets_from_user_interests"

Expand All @@ -45,7 +45,7 @@ def test_local_prompt_provider(tmp_path):
)
os.chdir(td)
provider = Provider()
prompt = provider.get_prompt_for("asking a user their location")
prompt = provider.get_item(query="asking a user their location", item_type="prompt")
assert prompt.tools is None
assert prompt.meta.name == "get_user_location"

Expand All @@ -62,8 +62,8 @@ def test_local_provider(tmp_path):
)
os.chdir(td)
provider = Provider()
prompt = provider.get_prompt_for("asking a user their location")
tools = provider.get_tools_for("searching travel blogs")
prompt = provider.get_item(query="asking a user their location", item_type="prompt")
tools = provider.get_item(query="searching travel blogs", item_type="tool")
assert len(tools) == 1
assert tools[0].__name__ == "get_travel_blog_snippets_from_user_interests"
assert prompt.tools is None
Expand All @@ -84,7 +84,7 @@ def test_db_tool_provider(tmp_path, isolated_server_factory):
os.chdir(td)
os.remove((pathlib.Path(td) / DEFAULT_CATALOG_FOLDER / DEFAULT_TOOL_CATALOG_NAME).absolute())
provider = Provider(bucket="travel-sample")
tools = provider.get_tools_for("searching travel blogs")
tools = provider.get_item(query="searching travel blogs", item_type="tool")
assert len(tools) == 1
assert tools[0].__name__ == "get_travel_blog_snippets_from_user_interests"

Expand Down
23 changes: 12 additions & 11 deletions templates/agents/with_controlflow/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
import dotenv

# Make sure you populate your .env file with the correct credentials!
dotenv.load_dotenv()

import agentc
import agentc.auditor
import agentc.langchain
Expand All @@ -11,13 +6,18 @@
import controlflow.events
import controlflow.orchestration
import controlflow.tools
import dotenv
import langchain_openai
import os
import pydantic
import uuid

from pydantic import SecretStr
from utils import TaskFactory

# Make sure you populate your .env file with the correct credentials!
dotenv.load_dotenv()

# The Agent Catalog provider serves versioned tools and prompts.
# For a comprehensive list of what parameters can be set here, see the class documentation.
# Parameters can also be set with environment variables (e.g., bucket = $AGENT_CATALOG_BUCKET).
Expand All @@ -29,9 +29,9 @@
# The 'values' of this dictionary map to actual values required by the tool.
# In this case, we get the Couchbase connection string, username, and password from environment variables.
secrets={
"CB_CONN_STRING": os.getenv("CB_CONN_STRING"),
"CB_USERNAME": os.getenv("CB_USERNAME"),
"CB_PASSWORD": os.getenv("CB_PASSWORD"),
"CB_CONN_STRING": SecretStr(os.getenv("CB_CONN_STRING")),
"CB_USERNAME": SecretStr(os.getenv("CB_USERNAME")),
"CB_PASSWORD": SecretStr(os.getenv("CB_PASSWORD")),
},
)

Expand Down Expand Up @@ -85,7 +85,7 @@ class EndpointsType(pydantic.BaseModel):
while True:
endpoints = task_factory.run(
# Search for prompts using your provider.
prompt=provider.get_prompt_for(query="asking for source and destination airports"),
prompt=provider.get_item(query="asking for source and destination airports", item_type="prompt"),
# All other arguments are forwarded to the ControlFlow Task constructor.
# Check out their docs here: https://controlflow.ai/concepts/tasks#task-properties
result_type=EndpointsType,
Expand All @@ -94,13 +94,14 @@ class EndpointsType(pydantic.BaseModel):
# We "draw" implicit dependency edges by using the results of previous tasks.
# In this example, all tasks are executed eagerly (though there is some limited support for lazy evaluation).
travel_routes = task_factory.run(
prompt=provider.get_prompt_for(query="finding routes between airports"),
prompt=provider.get_item(query="finding routes between airports", item_type="prompt"),
context={"source_airport": endpoints.source_airport, "destination_airport": endpoints.dest_airport},
result_type=str,
)
print(f"Your routes are: {travel_routes}")
is_continue = task_factory.run(
prompt=provider.get_prompt_for(query="after addressing a user's request"), result_type=[True, False]
prompt=provider.get_item(query="after addressing a user's request", item_type="prompt"),
result_type=[True, False],
)
if not is_continue:
break
Expand Down
Loading