diff --git a/sdk/python/Makefile b/sdk/python/Makefile index 68301bc..489c58a 100644 --- a/sdk/python/Makefile +++ b/sdk/python/Makefile @@ -7,3 +7,4 @@ generate: mv agent_protocol/main.py agent_protocol/server.py rm -rf agent_protocol/routers rm agent_protocol/dependencies.py + black . \ No newline at end of file diff --git a/sdk/python/agent_protocol/__init__.py b/sdk/python/agent_protocol/__init__.py index a09870a..398e78c 100644 --- a/sdk/python/agent_protocol/__init__.py +++ b/sdk/python/agent_protocol/__init__.py @@ -1,16 +1,17 @@ from .agent import Agent, StepHandler, TaskHandler, base_router as router -from .models import Artifact, StepRequestBody, TaskRequestBody +from .models import Artifact, Status, StepRequestBody, TaskRequestBody from .db import Step, Task, TaskDB __all__ = [ "Agent", "Artifact", + "Status", "Step", "StepHandler", + "StepRequestBody", "Task", "TaskDB", - "StepRequestBody", "TaskHandler", "TaskRequestBody", "router", diff --git a/sdk/python/agent_protocol/agent.py b/sdk/python/agent_protocol/agent.py index 8c894c5..34307c6 100644 --- a/sdk/python/agent_protocol/agent.py +++ b/sdk/python/agent_protocol/agent.py @@ -1,27 +1,26 @@ import asyncio import os +from uuid import uuid4 import aiofiles from fastapi import APIRouter, UploadFile, Form, File from fastapi.responses import FileResponse from hypercorn.asyncio import serve from hypercorn.config import Config -from typing import Awaitable, Callable, List, Optional, Annotated +from typing import Callable, List, Optional, Annotated, Coroutine, Any -from .db import InMemoryTaskDB, TaskDB +from .db import InMemoryTaskDB, Task, TaskDB, Step from .server import app from .models import ( TaskRequestBody, - Step, StepRequestBody, Artifact, - Task, Status, ) -StepHandler = Callable[[Step], Awaitable[Step]] -TaskHandler = Callable[[Task], Awaitable[None]] +StepHandler = Callable[[Step], Coroutine[Any, Any, Step]] +TaskHandler = Callable[[Task], Coroutine[Any, Any, None]] _task_handler: Optional[TaskHandler] @@ -89,12 +88,17 @@ async def execute_agent_task_step( """ Execute a step in the specified agent task. """ + if not _step_handler: + raise Exception("Step handler not defined") + task = await Agent.db.get_task(task_id) step = next(filter(lambda x: x.status == Status.created, task.steps), None) if not step: raise Exception("No steps to execute") + step.status = Status.running + step.input = body.input if body else None step.additional_input = body.additional_input if body else None @@ -109,7 +113,7 @@ async def execute_agent_task_step( response_model=Step, tags=["agent"], ) -async def get_agent_task_step(task_id: str, step_id: str = ...) -> Step: +async def get_agent_task_step(task_id: str, step_id: str) -> Step: """ Get details about a specified task step. """ @@ -142,14 +146,15 @@ async def upload_agent_task_artifacts( """ Upload an artifact for the specified task. """ + file_name = file.filename or str(uuid4()) await Agent.db.get_task(task_id) - artifact = await Agent.db.create_artifact(task_id, file.filename, relative_path) + artifact = await Agent.db.create_artifact(task_id, file_name, relative_path) path = Agent.get_artifact_folder(task_id, artifact) if not os.path.exists(path): os.makedirs(path) - async with aiofiles.open(os.path.join(path, file.filename), "wb") as f: + async with aiofiles.open(os.path.join(path, file_name), "wb") as f: while content := await file.read(1024 * 1024): # async read chunk ~1MiB await f.write(content) diff --git a/sdk/python/agent_protocol/db.py b/sdk/python/agent_protocol/db.py index e48ce49..a500e79 100644 --- a/sdk/python/agent_protocol/db.py +++ b/sdk/python/agent_protocol/db.py @@ -12,13 +12,24 @@ class Task(APITask): steps: List[Step] = [] +class NotFoundException(Exception): + """ + Exception raised when a resource is not found. + """ + + def __init__(self, item_name: str, item_id: str): + self.item_name = item_name + self.item_id = item_id + super().__init__(f"{item_name} with {item_id} not found.") + + class TaskDB(ABC): async def create_task( self, input: Optional[str], - additional_input: Optional[str] = None, - artifacts: List[Artifact] = None, - steps: List[Step] = None, + additional_input: Any = None, + artifacts: Optional[List[Artifact]] = None, + steps: Optional[List[Step]] = None, ) -> Task: raise NotImplementedError @@ -26,6 +37,7 @@ async def create_step( self, task_id: str, name: Optional[str] = None, + input: Optional[str] = None, is_last: bool = False, additional_properties: Optional[Dict[str, str]] = None, ) -> Step: @@ -52,7 +64,9 @@ async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: async def list_tasks(self) -> List[Task]: raise NotImplementedError - async def list_steps(self, task_id: str) -> List[Step]: + async def list_steps( + self, task_id: str, status: Optional[Status] = None + ) -> List[Step]: raise NotImplementedError @@ -62,9 +76,9 @@ class InMemoryTaskDB(TaskDB): async def create_task( self, input: Optional[str], - additional_input: Optional[str] = None, - artifacts: List[Artifact] = None, - steps: List[Step] = None, + additional_input: Any = None, + artifacts: Optional[List[Artifact]] = None, + steps: Optional[List[Step]] = None, ) -> Task: if not steps: steps = [] @@ -85,14 +99,16 @@ async def create_step( self, task_id: str, name: Optional[str] = None, + input: Optional[str] = None, is_last=False, - additional_properties: Dict[str, Any] = None, + additional_properties: Optional[Dict[str, Any]] = None, ) -> Step: step_id = str(uuid.uuid4()) step = Step( task_id=task_id, step_id=step_id, name=name, + input=input, status=Status.created, is_last=is_last, additional_properties=additional_properties, @@ -104,14 +120,14 @@ async def create_step( async def get_task(self, task_id: str) -> Task: task = self._tasks.get(task_id, None) if not task: - raise Exception(f"Task with id {task_id} not found") + raise NotFoundException("Task", task_id) return task async def get_step(self, task_id: str, step_id: str) -> Step: task = await self.get_task(task_id) step = next(filter(lambda s: s.task_id == task_id, task.steps), None) if not step: - raise Exception(f"Step with id {step_id} not found") + raise NotFoundException("Step", step_id) return step async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: @@ -120,7 +136,7 @@ async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None ) if not artifact: - raise Exception(f"Artifact with id {artifact_id} not found") + raise NotFoundException("Artifact", artifact_id) return artifact async def create_artifact( @@ -146,6 +162,11 @@ async def create_artifact( async def list_tasks(self) -> List[Task]: return [task for task in self._tasks.values()] - async def list_steps(self, task_id: str) -> List[Step]: + async def list_steps( + self, task_id: str, status: Optional[Status] = None + ) -> List[Step]: task = await self.get_task(task_id) - return [step for step in task.steps] + steps = task.steps + if status: + steps = list(filter(lambda s: s.status == status, steps)) + return steps diff --git a/sdk/python/agent_protocol/middlewares.py b/sdk/python/agent_protocol/middlewares.py new file mode 100644 index 0000000..0443396 --- /dev/null +++ b/sdk/python/agent_protocol/middlewares.py @@ -0,0 +1,13 @@ +from fastapi import Request +from fastapi.responses import PlainTextResponse + +from agent_protocol.db import NotFoundException + + +async def not_found_exception_handler( + request: Request, exc: NotFoundException +) -> PlainTextResponse: + return PlainTextResponse( + str(exc), + status_code=404, + ) diff --git a/sdk/python/agent_protocol/models.py b/sdk/python/agent_protocol/models.py index 0214b9b..6a8fe4e 100644 --- a/sdk/python/agent_protocol/models.py +++ b/sdk/python/agent_protocol/models.py @@ -1,6 +1,6 @@ # generated by fastapi-codegen: # filename: ../../openapi.yml -# timestamp: 2023-08-07T12:14:43+00:00 +# timestamp: 2023-08-11T14:24:22+00:00 from __future__ import annotations @@ -12,65 +12,111 @@ class TaskInput(BaseModel): __root__: Any = Field( - ..., description="Input parameters for the task. Any value is allowed." + ..., + description="Input parameters for the task. Any value is allowed.", + example='{\n"debug": false,\n"mode": "benchmarks"\n}', ) class Artifact(BaseModel): - artifact_id: str = Field(..., description="ID of the artifact.") - file_name: str = Field(..., description="Filename of the artifact.") + artifact_id: str = Field( + ..., + description="ID of the artifact.", + example="b225e278-8b4c-4f99-a696-8facf19f0e56", + ) + file_name: str = Field( + ..., description="Filename of the artifact.", example="main.py" + ) relative_path: Optional[str] = Field( - None, description="Relative path of the artifact in the agent's workspace." + None, + description="Relative path of the artifact in the agent's workspace.", + example="python/code/", ) class ArtifactUpload(BaseModel): file: bytes = Field(..., description="File to upload.") relative_path: Optional[str] = Field( - None, description="Relative path of the artifact in the agent's workspace." + None, + description="Relative path of the artifact in the agent's workspace.", + example="python/code", ) class StepInput(BaseModel): __root__: Any = Field( - ..., description="Input parameters for the task step. Any value is allowed." + ..., + description="Input parameters for the task step. Any value is allowed.", + example='{\n"file_to_refactor": "models.py"\n}', ) class StepOutput(BaseModel): __root__: Any = Field( - ..., description="Output that the task step has produced. Any value is allowed." + ..., + description="Output that the task step has produced. Any value is allowed.", + example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}', ) class TaskRequestBody(BaseModel): - input: Optional[str] = Field(None, description="Input prompt for the task.") + input: Optional[str] = Field( + None, + description="Input prompt for the task.", + example="Write the words you receive to the file 'output.txt'.", + ) additional_input: Optional[TaskInput] = None class Task(TaskRequestBody): - task_id: str = Field(..., description="The ID of the task.") + task_id: str = Field( + ..., + description="The ID of the task.", + example="50da533e-3904-4401-8a07-c49adf88b5eb", + ) artifacts: List[Artifact] = Field( - [], description="A list of artifacts that the task has produced." + [], + description="A list of artifacts that the task has produced.", + example=[ + "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e", + "ab7b4091-2560-4692-a4fe-d831ea3ca7d6", + ], ) class StepRequestBody(BaseModel): - input: Optional[str] = Field(None, description="Input prompt for the step.") + input: Optional[str] = Field( + None, description="Input prompt for the step.", example="Washington" + ) additional_input: Optional[StepInput] = None class Status(Enum): created = "created" + running = "running" completed = "completed" class Step(StepRequestBody): - task_id: str = Field(..., description="The ID of the task this step belongs to.") - step_id: str = Field(..., description="The ID of the task step.") - name: Optional[str] = Field(None, description="The name of the task step.") + task_id: str = Field( + ..., + description="The ID of the task this step belongs to.", + example="50da533e-3904-4401-8a07-c49adf88b5eb", + ) + step_id: str = Field( + ..., + description="The ID of the task step.", + example="6bb1801a-fd80-45e8-899a-4dd723cc602e", + ) + name: Optional[str] = Field( + None, description="The name of the task step.", example="Write to file" + ) status: Status = Field(..., description="The status of the task step.") - output: Optional[str] = Field(None, description="Output of the task step.") + output: Optional[str] = Field( + None, + description="Output of the task step.", + example="I am going to use the write_to_file command and write Washington to a file called output.txt None: await Agent.db.create_step(task.task_id, StepTypes.PLAN) -async def step_handler(step: Step): +async def step_handler(step: Step) -> Step: task = await Agent.db.get_task(step.task_id) if step.name == StepTypes.PLAN: return await _generate_shared_deps(step) diff --git a/sdk/python/pyproject.toml b/sdk/python/pyproject.toml index 45fc2c8..8852380 100644 --- a/sdk/python/pyproject.toml +++ b/sdk/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "agent-protocol" -version = "0.2.4" +version = "0.3.0" description = "API for interacting with Agent" authors = ["e2b "] license = "MIT"