Skip to content

Commit

Permalink
feat: implement pre retrieval and detection data source flows
Browse files Browse the repository at this point in the history
  • Loading branch information
okradze committed Oct 31, 2023
1 parent 9494b93 commit 8a417f3
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 66 deletions.
5 changes: 4 additions & 1 deletion apps/server/agents/conversational/conversational.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def run(
run_id: UUID,
sender_user_id: str,
run_logs_manager: RunLogsManager,
pre_retrieved_context: str,
):
memory = ZepMemory(
session_id=str(self.session_id),
Expand All @@ -41,7 +42,9 @@ def run(
memory.human_name = self.sender_name
memory.ai_name = agent_with_configs.agent.name

system_message = SystemMessageBuilder(agent_with_configs).build()
system_message = SystemMessageBuilder(
agent_with_configs, pre_retrieved_context
).build()

run_logs_manager.create_system_run_log(system_message)

Expand Down
14 changes: 12 additions & 2 deletions apps/server/services/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from services.run_log import RunLogsManager
from tools.datasources.get_datasource_tools import get_datasource_tools
from tools.get_tools import get_agent_tools
from typings.agent import AgentWithConfigsOutput
from typings.agent import AgentWithConfigsOutput, DataSourceFlow
from typings.auth import UserAccount
from typings.chat import ChatMessageInput, ChatStatus, ChatUserMessageInput
from typings.config import AccountSettings, ConfigInput
Expand Down Expand Up @@ -566,7 +566,16 @@ def run_conversational_agent(
agent_with_configs,
tool_callback_handler,
)
tools = datasource_tools + agent_tools

pre_retrieved_context = ""

if agent_with_configs.configs.source_flow == DataSourceFlow.PRE_RETRIEVAL.value:
if len(datasource_tools) != 0:
pre_retrieved_context = datasource_tools[0]._run(prompt)

tools = agent_tools
else:
tools = datasource_tools + agent_tools

conversational = ConversationalAgent(sender_name, provider_account, session_id)
return conversational.run(
Expand All @@ -580,6 +589,7 @@ def run_conversational_agent(
run_id,
sender_user_id,
run_logs_manager,
pre_retrieved_context,
)


Expand Down
1 change: 1 addition & 0 deletions apps/server/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class BaseTool(LangchainBaseTool):
toolkit_slug: Optional[str] = None
account: Optional[AccountOutput] = None
agent_with_configs: Optional[AgentWithConfigsOutput] = None
data_source_id: Optional[str] = None

def get_env_key(self, key: str):
return self.configs.get(key)
Expand Down
14 changes: 5 additions & 9 deletions apps/server/tools/datasources/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@


class FileDatasourceSchema(BaseModel):
query: str = Field(
description="Containing Datasource Id and question in English natural language, separated by semicolon"
)
query: str = Field(description="Containing question in English natural language")


class FileDatasourceTool(BaseTool):
name = "File Datasource Q&A"

description = (
"useful for when you need to answer questions over File datasource.\n"
"Input is string. String is separated by semicolon. First is question in English natural language. Second is datasource ID."
"Input is a question in English natural language"
)

args_schema: Type[FileDatasourceSchema] = FileDatasourceSchema
Expand All @@ -34,10 +32,8 @@ def _run(
) -> str:
"""Ask questions over file datasource. Return result."""

question, datasource_id = query.split(";")

configs = ConfigModel.get_configs(
db, ConfigQueryParams(datasource_id=datasource_id), self.account
db, ConfigQueryParams(datasource_id=self.data_source_id), self.account
)
files_config = [config for config in configs if config.key == "files"][0]

Expand All @@ -54,11 +50,11 @@ def _run(
response_mode,
vector_store,
str(self.account.id),
datasource_id,
self.data_source_id,
self.agent_with_configs,
chunk_size,
similarity_top_k,
)
retriever.load_index()
result = retriever.query(question)
result = retriever.query(query)
return result
27 changes: 15 additions & 12 deletions apps/server/tools/datasources/get_datasource_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,26 @@ def get_datasource_tools(

tools: List[BaseTool] = []

datasource_types = [datasource.source_type for datasource in datasources]
datasource_types = list(set(datasource_types))

for datasource_type in datasource_types:
if datasource_type == DatasourceType.POSTGRES.value:
tools.append(PostgresDatabaseTool())
if datasource_type == DatasourceType.MYSQL.value:
tools.append(MySQLDatabaseTool())
if datasource_type == DatasourceType.FILE.value:
tools.append(FileDatasourceTool())

for tool in tools:
for data_source in datasources:
tool = None

if data_source.source_type == DatasourceType.POSTGRES.value:
tool = PostgresDatabaseTool()
if data_source.source_type == DatasourceType.MYSQL.value:
tool = MySQLDatabaseTool()
if data_source.source_type == DatasourceType.FILE.value:
tool = FileDatasourceTool()

tool.name = f"{data_source.name} Data Source"
tool.description = data_source.description
tool.settings = settings
tool.account = account
tool.agent_with_configs = agent_with_configs
tool.data_source_id = str(data_source.id)

if callback_handler:
tool.callbacks = [callback_handler]

tools.append(tool)

return tools
12 changes: 4 additions & 8 deletions apps/server/tools/datasources/mysql/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class MySQLDatabaseSchema(BaseModel):
query: str = Field(
description="Containing Datasource Id and database question in English natural language, separated by semicolon"
description="Containing database question in English natural language. It is not SQL script!"
)


Expand All @@ -21,8 +21,7 @@ class MySQLDatabaseTool(BaseTool):

description = (
"useful for when you need to answer questions over MySQL datasource.\n"
"Input is string. String is separated by semicolon. First is database question in English natural language. Second is datasource ID.\n"
"First part of input is English question and it is not SQL script!\n"
"Input is database question in English natural language. it is not SQL script!\n"
)

args_schema: Type[MySQLDatabaseSchema] = MySQLDatabaseSchema
Expand All @@ -34,11 +33,10 @@ def _run(
) -> str:
"""Convert natural language to SQL Query and execute. Return result."""

question, datasource_id = query.split(";")
configs = (
db.session.query(ConfigModel)
.where(
ConfigModel.datasource_id == datasource_id,
ConfigModel.datasource_id == self.data_source_id,
ConfigModel.is_deleted.is_(False),
)
.all()
Expand All @@ -59,7 +57,5 @@ def _run(

uri = f"mysql+pymysql://{user}:{password}@{host}:{port}/{name}"

result = SQLQueryEngine(self.settings, self.agent_with_configs, uri).run(
question
)
result = SQLQueryEngine(self.settings, self.agent_with_configs, uri).run(query)
return result
13 changes: 4 additions & 9 deletions apps/server/tools/datasources/postgres/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class PostgresDatabaseSchema(BaseModel):
query: str = Field(
description="Containing Datasource Id and database question in English natural language, separated by semicolon"
description="Containing database question in English natural language. It is not SQL script!"
)


Expand All @@ -21,10 +21,8 @@ class PostgresDatabaseTool(BaseTool):

description = (
"useful for when you need to answer questions over Postgres datasource.\n"
"Input is string. String is separated by semicolon. First is database question in English natural language. Second is datasource ID.\n"
"First part of input is English question and it is not SQL script!\n"
"Input is database question in English natural language. it is not SQL script!\n"
)

args_schema: Type[PostgresDatabaseSchema] = PostgresDatabaseSchema

tool_id = "f5a8fec0-7399-42f5-a076-be3a8c85b689"
Expand All @@ -34,11 +32,10 @@ def _run(
) -> str:
"""Convert natural language to SQL Query and execute. Return result."""

question, datasource_id = query.split(";")
configs = (
db.session.query(ConfigModel)
.where(
ConfigModel.datasource_id == datasource_id,
ConfigModel.datasource_id == self.data_source_id,
ConfigModel.is_deleted.is_(False),
)
.all()
Expand All @@ -59,7 +56,5 @@ def _run(

uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{name}"

result = SQLQueryEngine(self.settings, self.agent_with_configs, uri).run(
question
)
result = SQLQueryEngine(self.settings, self.agent_with_configs, uri).run(query)
return result
6 changes: 6 additions & 0 deletions apps/server/typings/agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from enum import Enum
from typing import List, Optional

from pydantic import UUID4, BaseModel

from typings.user import UserOutput


class DataSourceFlow(Enum):
PRE_RETRIEVAL = "pre_execution"
SOURCE_DETECTION = "source_detection"


class AgentInput(BaseModel):
name: str
description: Optional[str]
Expand Down
37 changes: 12 additions & 25 deletions apps/server/utils/system_message.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from typing import List, Optional

from fastapi_sqlalchemy import db

from models.datasource import DatasourceModel
from typings.agent import AgentWithConfigsOutput


class SystemMessageBuilder:
def __init__(self, agent_with_configs: AgentWithConfigsOutput):
def __init__(
self,
agent_with_configs: AgentWithConfigsOutput,
pre_retrieved_context: Optional[str] = "",
):
self.agent = agent_with_configs.agent
self.configs = agent_with_configs.configs
self.data_source_pre_retrieval = False
self.pre_retrieved_context = pre_retrieved_context

def build(self) -> str:
base_system_message = self.build_base_system_message(self.configs.text)
Expand All @@ -18,9 +21,9 @@ def build(self) -> str:
goals = self.build_goals(self.configs.goals)
instructions = self.build_instructions(self.configs.instructions)
constraints = self.build_constraints(self.configs.constraints)
data_sources = self.build_data_sources(self.configs.datasources)
context = self.build_pre_retrieved_context(self.pre_retrieved_context)

result = f"{base_system_message}{role}{description}{goals}{instructions}{constraints}{data_sources}"
result = f"{base_system_message}{role}{description}{goals}{instructions}{constraints}{context}"
return result

def build_base_system_message(self, text: str) -> str:
Expand Down Expand Up @@ -70,26 +73,10 @@ def build_constraints(self, constraints: List[str]):
)
return constraints

def build_data_sources(self, datasource_ids: List[str]):
"""Builds the data sources section of the system message."""
if len(datasource_ids) == 0:
def build_pre_retrieved_context(self, text: str):
if text is None or text == "":
return ""

data_sources = (
db.session.query(DatasourceModel)
.filter(DatasourceModel.id.in_(datasource_ids))
.all()
)

result = (
"DATASOURCES:"
"Data sources can be: SQL databases and files. You can use tools to get data from them.\n"
"You can use the following data sources:\n"
)

for data_source in data_sources:
result += f"- Data source Type: {data_source.source_type}, Data source Name: {data_source.name}, Useful for: {data_source.description}, Data source Id for tool: {data_source.id} \n"

result += "\n"
result = "CONTEXT DATA: \n" f"{text}\n"

return result

0 comments on commit 8a417f3

Please sign in to comment.