From 8a417f328c4324c53924aa549279414fc0ce87d0 Mon Sep 17 00:00:00 2001 From: Mirian Okradze Date: Tue, 31 Oct 2023 20:39:04 +0400 Subject: [PATCH] feat: implement pre retrieval and detection data source flows --- .../agents/conversational/conversational.py | 5 ++- apps/server/services/chat.py | 14 ++++++- apps/server/tools/base.py | 1 + apps/server/tools/datasources/file/file.py | 14 +++---- .../tools/datasources/get_datasource_tools.py | 27 ++++++++------ apps/server/tools/datasources/mysql/mysql.py | 12 ++---- .../tools/datasources/postgres/postgres.py | 13 ++----- apps/server/typings/agent.py | 6 +++ apps/server/utils/system_message.py | 37 ++++++------------- 9 files changed, 63 insertions(+), 66 deletions(-) diff --git a/apps/server/agents/conversational/conversational.py b/apps/server/agents/conversational/conversational.py index c80b2c2f0..ab3d5bbdf 100644 --- a/apps/server/agents/conversational/conversational.py +++ b/apps/server/agents/conversational/conversational.py @@ -29,6 +29,7 @@ def run( run_id: UUID, sender_user_id: str, run_logs_manager: RunLogsManager, + pre_retrieved_context: str, ): memory = ZepMemory( session_id=str(self.session_id), @@ -41,7 +42,9 @@ def run( memory.human_name = self.sender_name memory.ai_name = agent_with_configs.agent.name - system_message = SystemMessageBuilder(agent_with_configs).build() + system_message = SystemMessageBuilder( + agent_with_configs, pre_retrieved_context + ).build() run_logs_manager.create_system_run_log(system_message) diff --git a/apps/server/services/chat.py b/apps/server/services/chat.py index 883d215a7..1aa8ae041 100644 --- a/apps/server/services/chat.py +++ b/apps/server/services/chat.py @@ -29,7 +29,7 @@ from services.run_log import RunLogsManager from tools.datasources.get_datasource_tools import get_datasource_tools from tools.get_tools import get_agent_tools -from typings.agent import AgentWithConfigsOutput +from typings.agent import AgentWithConfigsOutput, DataSourceFlow from typings.auth import UserAccount from typings.chat import ChatMessageInput, ChatStatus, ChatUserMessageInput from typings.config import AccountSettings, ConfigInput @@ -566,7 +566,16 @@ def run_conversational_agent( agent_with_configs, tool_callback_handler, ) - tools = datasource_tools + agent_tools + + pre_retrieved_context = "" + + if agent_with_configs.configs.source_flow == DataSourceFlow.PRE_RETRIEVAL.value: + if len(datasource_tools) != 0: + pre_retrieved_context = datasource_tools[0]._run(prompt) + + tools = agent_tools + else: + tools = datasource_tools + agent_tools conversational = ConversationalAgent(sender_name, provider_account, session_id) return conversational.run( @@ -580,6 +589,7 @@ def run_conversational_agent( run_id, sender_user_id, run_logs_manager, + pre_retrieved_context, ) diff --git a/apps/server/tools/base.py b/apps/server/tools/base.py index 725b17e64..6de306a93 100644 --- a/apps/server/tools/base.py +++ b/apps/server/tools/base.py @@ -55,6 +55,7 @@ class BaseTool(LangchainBaseTool): toolkit_slug: Optional[str] = None account: Optional[AccountOutput] = None agent_with_configs: Optional[AgentWithConfigsOutput] = None + data_source_id: Optional[str] = None def get_env_key(self, key: str): return self.configs.get(key) diff --git a/apps/server/tools/datasources/file/file.py b/apps/server/tools/datasources/file/file.py index 95b2f9c97..ae6762bf3 100644 --- a/apps/server/tools/datasources/file/file.py +++ b/apps/server/tools/datasources/file/file.py @@ -12,9 +12,7 @@ class FileDatasourceSchema(BaseModel): - query: str = Field( - description="Containing Datasource Id and question in English natural language, separated by semicolon" - ) + query: str = Field(description="Containing question in English natural language") class FileDatasourceTool(BaseTool): @@ -22,7 +20,7 @@ class FileDatasourceTool(BaseTool): description = ( "useful for when you need to answer questions over File datasource.\n" - "Input is string. String is separated by semicolon. First is question in English natural language. Second is datasource ID." + "Input is a question in English natural language" ) args_schema: Type[FileDatasourceSchema] = FileDatasourceSchema @@ -34,10 +32,8 @@ def _run( ) -> str: """Ask questions over file datasource. Return result.""" - question, datasource_id = query.split(";") - configs = ConfigModel.get_configs( - db, ConfigQueryParams(datasource_id=datasource_id), self.account + db, ConfigQueryParams(datasource_id=self.data_source_id), self.account ) files_config = [config for config in configs if config.key == "files"][0] @@ -54,11 +50,11 @@ def _run( response_mode, vector_store, str(self.account.id), - datasource_id, + self.data_source_id, self.agent_with_configs, chunk_size, similarity_top_k, ) retriever.load_index() - result = retriever.query(question) + result = retriever.query(query) return result diff --git a/apps/server/tools/datasources/get_datasource_tools.py b/apps/server/tools/datasources/get_datasource_tools.py index d9755d11d..54d7b1dca 100644 --- a/apps/server/tools/datasources/get_datasource_tools.py +++ b/apps/server/tools/datasources/get_datasource_tools.py @@ -23,23 +23,26 @@ def get_datasource_tools( tools: List[BaseTool] = [] - datasource_types = [datasource.source_type for datasource in datasources] - datasource_types = list(set(datasource_types)) - - for datasource_type in datasource_types: - if datasource_type == DatasourceType.POSTGRES.value: - tools.append(PostgresDatabaseTool()) - if datasource_type == DatasourceType.MYSQL.value: - tools.append(MySQLDatabaseTool()) - if datasource_type == DatasourceType.FILE.value: - tools.append(FileDatasourceTool()) - - for tool in tools: + for data_source in datasources: + tool = None + + if data_source.source_type == DatasourceType.POSTGRES.value: + tool = PostgresDatabaseTool() + if data_source.source_type == DatasourceType.MYSQL.value: + tool = MySQLDatabaseTool() + if data_source.source_type == DatasourceType.FILE.value: + tool = FileDatasourceTool() + + tool.name = f"{data_source.name} Data Source" + tool.description = data_source.description tool.settings = settings tool.account = account tool.agent_with_configs = agent_with_configs + tool.data_source_id = str(data_source.id) if callback_handler: tool.callbacks = [callback_handler] + tools.append(tool) + return tools diff --git a/apps/server/tools/datasources/mysql/mysql.py b/apps/server/tools/datasources/mysql/mysql.py index e36318ed7..a2396eb93 100644 --- a/apps/server/tools/datasources/mysql/mysql.py +++ b/apps/server/tools/datasources/mysql/mysql.py @@ -12,7 +12,7 @@ class MySQLDatabaseSchema(BaseModel): query: str = Field( - description="Containing Datasource Id and database question in English natural language, separated by semicolon" + description="Containing database question in English natural language. It is not SQL script!" ) @@ -21,8 +21,7 @@ class MySQLDatabaseTool(BaseTool): description = ( "useful for when you need to answer questions over MySQL datasource.\n" - "Input is string. String is separated by semicolon. First is database question in English natural language. Second is datasource ID.\n" - "First part of input is English question and it is not SQL script!\n" + "Input is database question in English natural language. it is not SQL script!\n" ) args_schema: Type[MySQLDatabaseSchema] = MySQLDatabaseSchema @@ -34,11 +33,10 @@ def _run( ) -> str: """Convert natural language to SQL Query and execute. Return result.""" - question, datasource_id = query.split(";") configs = ( db.session.query(ConfigModel) .where( - ConfigModel.datasource_id == datasource_id, + ConfigModel.datasource_id == self.data_source_id, ConfigModel.is_deleted.is_(False), ) .all() @@ -59,7 +57,5 @@ def _run( uri = f"mysql+pymysql://{user}:{password}@{host}:{port}/{name}" - result = SQLQueryEngine(self.settings, self.agent_with_configs, uri).run( - question - ) + result = SQLQueryEngine(self.settings, self.agent_with_configs, uri).run(query) return result diff --git a/apps/server/tools/datasources/postgres/postgres.py b/apps/server/tools/datasources/postgres/postgres.py index 8b12787c0..1a50b3e75 100644 --- a/apps/server/tools/datasources/postgres/postgres.py +++ b/apps/server/tools/datasources/postgres/postgres.py @@ -12,7 +12,7 @@ class PostgresDatabaseSchema(BaseModel): query: str = Field( - description="Containing Datasource Id and database question in English natural language, separated by semicolon" + description="Containing database question in English natural language. It is not SQL script!" ) @@ -21,10 +21,8 @@ class PostgresDatabaseTool(BaseTool): description = ( "useful for when you need to answer questions over Postgres datasource.\n" - "Input is string. String is separated by semicolon. First is database question in English natural language. Second is datasource ID.\n" - "First part of input is English question and it is not SQL script!\n" + "Input is database question in English natural language. it is not SQL script!\n" ) - args_schema: Type[PostgresDatabaseSchema] = PostgresDatabaseSchema tool_id = "f5a8fec0-7399-42f5-a076-be3a8c85b689" @@ -34,11 +32,10 @@ def _run( ) -> str: """Convert natural language to SQL Query and execute. Return result.""" - question, datasource_id = query.split(";") configs = ( db.session.query(ConfigModel) .where( - ConfigModel.datasource_id == datasource_id, + ConfigModel.datasource_id == self.data_source_id, ConfigModel.is_deleted.is_(False), ) .all() @@ -59,7 +56,5 @@ def _run( uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{name}" - result = SQLQueryEngine(self.settings, self.agent_with_configs, uri).run( - question - ) + result = SQLQueryEngine(self.settings, self.agent_with_configs, uri).run(query) return result diff --git a/apps/server/typings/agent.py b/apps/server/typings/agent.py index c984240d9..1d246dc49 100644 --- a/apps/server/typings/agent.py +++ b/apps/server/typings/agent.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import List, Optional from pydantic import UUID4, BaseModel @@ -5,6 +6,11 @@ from typings.user import UserOutput +class DataSourceFlow(Enum): + PRE_RETRIEVAL = "pre_execution" + SOURCE_DETECTION = "source_detection" + + class AgentInput(BaseModel): name: str description: Optional[str] diff --git a/apps/server/utils/system_message.py b/apps/server/utils/system_message.py index 610369eec..187677c48 100644 --- a/apps/server/utils/system_message.py +++ b/apps/server/utils/system_message.py @@ -1,15 +1,18 @@ from typing import List, Optional -from fastapi_sqlalchemy import db - -from models.datasource import DatasourceModel from typings.agent import AgentWithConfigsOutput class SystemMessageBuilder: - def __init__(self, agent_with_configs: AgentWithConfigsOutput): + def __init__( + self, + agent_with_configs: AgentWithConfigsOutput, + pre_retrieved_context: Optional[str] = "", + ): self.agent = agent_with_configs.agent self.configs = agent_with_configs.configs + self.data_source_pre_retrieval = False + self.pre_retrieved_context = pre_retrieved_context def build(self) -> str: base_system_message = self.build_base_system_message(self.configs.text) @@ -18,9 +21,9 @@ def build(self) -> str: goals = self.build_goals(self.configs.goals) instructions = self.build_instructions(self.configs.instructions) constraints = self.build_constraints(self.configs.constraints) - data_sources = self.build_data_sources(self.configs.datasources) + context = self.build_pre_retrieved_context(self.pre_retrieved_context) - result = f"{base_system_message}{role}{description}{goals}{instructions}{constraints}{data_sources}" + result = f"{base_system_message}{role}{description}{goals}{instructions}{constraints}{context}" return result def build_base_system_message(self, text: str) -> str: @@ -70,26 +73,10 @@ def build_constraints(self, constraints: List[str]): ) return constraints - def build_data_sources(self, datasource_ids: List[str]): - """Builds the data sources section of the system message.""" - if len(datasource_ids) == 0: + def build_pre_retrieved_context(self, text: str): + if text is None or text == "": return "" - data_sources = ( - db.session.query(DatasourceModel) - .filter(DatasourceModel.id.in_(datasource_ids)) - .all() - ) - - result = ( - "DATASOURCES:" - "Data sources can be: SQL databases and files. You can use tools to get data from them.\n" - "You can use the following data sources:\n" - ) - - for data_source in data_sources: - result += f"- Data source Type: {data_source.source_type}, Data source Name: {data_source.name}, Useful for: {data_source.description}, Data source Id for tool: {data_source.id} \n" - - result += "\n" + result = "CONTEXT DATA: \n" f"{text}\n" return result