diff --git a/agents/sirji_agents/llm/generic/system_prompts/default.py b/agents/sirji_agents/llm/generic/system_prompts/default.py index 3ddb9a0..f703b57 100644 --- a/agents/sirji_agents/llm/generic/system_prompts/default.py +++ b/agents/sirji_agents/llm/generic/system_prompts/default.py @@ -99,6 +99,9 @@ def system_prompt(self): action_list.add(ActionEnum[action]) allowed_response_templates_str += '\n' + allowed_response_templates(AgentEnum.ANY, AgentEnum.EXECUTOR, action_list) + '\n' + action_list_researcher = permissions_dict[(AgentEnum.ANY, AgentEnum.RESEARCHER)] + allowed_response_templates_str += '\n' + allowed_response_templates(AgentEnum.ANY, AgentEnum.RESEARCHER, action_list_researcher) + '\n' + allowed_response_templates_str += textwrap.dedent(f"""For updating in project folder use either {ActionEnum.FIND_AND_REPLACE.name}, {ActionEnum.INSERT_ABOVE.name} or {ActionEnum.INSERT_BELOW.name} actions. Ensure you provide the exact matching string in find from file, with the exact number of lines and proper indentation for insert and replace actions.""") + '\n' allowed_response_templates_str += '\n' + allowed_response_templates(AgentEnum.ANY, AgentEnum.CALLER, permissions_dict[(AgentEnum.ANY, AgentEnum.CALLER)]) + '\n' diff --git a/agents/sirji_agents/researcher/cleanup/openai_cleanup.py b/agents/sirji_agents/researcher/cleanup/openai_cleanup.py index b475054..c94df84 100644 --- a/agents/sirji_agents/researcher/cleanup/openai_cleanup.py +++ b/agents/sirji_agents/researcher/cleanup/openai_cleanup.py @@ -49,6 +49,16 @@ def delete_file(self, file_path): response = self.client.files.delete(file_path) print(response) self.logger.info(response) + except Exception as e: + print(e) + self.logger.error(e) + + @retry_on_exception() + def delete_thread(self, thread_id): + try: + response = self.client.beta.threads.delete(thread_id = thread_id) + print(response) + self.logger.info(response) except Exception as e: print(e) self.logger.error(e) \ No newline at end of file diff --git a/agents/sirji_agents/researcher/inferer/openai_assistant.py b/agents/sirji_agents/researcher/inferer/openai_assistant.py index 4973607..1d6fd2a 100644 --- a/agents/sirji_agents/researcher/inferer/openai_assistant.py +++ b/agents/sirji_agents/researcher/inferer/openai_assistant.py @@ -1,3 +1,4 @@ +import json import os import time from openai import OpenAI @@ -18,6 +19,7 @@ def __init__(self, init_payload): """ self.init_payload = init_payload + self.thread_id = self.init_payload.get('thread_id', None) # Fetch OpenAI API key from an environment variable api_key = os.environ.get("SIRJI_MODEL_PROVIDER_API_KEY") @@ -37,6 +39,7 @@ def __init__(self, init_payload): self.assistant_id = self.init_payload['assistant_id'] self.logger.info("Completed initializing OpenAI Assistant Inferer") + self.assistant_details_path = os.path.join(self._get_run_path(), "assistant_details.json") def infer(self, problem_statement): self.logger.info("Started inferring using OpenAI Assistant Inferer") @@ -56,26 +59,49 @@ def infer(self, problem_statement): # content=problem_statement, # ) - self.thread_id = None try: - self.thread_id = self.create_thread(problem_statement) + if not self.thread_id: + self.thread_id = self.create_thread(problem_statement) + if self.thread_id: + try: + self.logger.info("File path: %s", self.assistant_details_path) + with open(self.assistant_details_path, 'r') as file: + assistant_details = json.load(file) + self.logger.info("Loaded assistant_details: %s", assistant_details) + + with open(self.assistant_details_path, 'w') as file: + + thread_ids_map = assistant_details.get('thread_ids_map') + complete_session_id = self.init_payload.get('complete_session_id') + + thread_ids = thread_ids_map.get(complete_session_id, None) + if thread_ids is None: + thread_ids_map[complete_session_id] = [] + + thread_ids_map[complete_session_id].append(self.thread_id) + + self.logger.info("Saving thread_id: %s", self.thread_id) + self.logger.info("Saving assistant_details: %s", assistant_details) + json.dump(assistant_details, file, indent=4) + except Exception as save_e: + self.logger.error("Failed to save thread_id: %s", str(save_e)) + print('self.thread_id', self.thread_id) + response = self._fetch_response() except Exception as e: self.logger.error("An error occurred during inference: %s", str(e)) response = 'An error occurred during inference', 0, 0 - finally: - if self.thread_id: - try: - deleteResponse = self.client.beta.threads.delete(thread_id=self.thread_id) - self.logger.info("Thread deleted successfully: %s", deleteResponse) - print('deleteResponse', deleteResponse) - except Exception as delete_e: - self.logger.error("Failed to delete thread: %s", str(delete_e)) - print('self.thread_id', self.thread_id) - + return response + def _get_run_path(self): + run_path = os.environ.get("SIRJI_RUN_PATH") + if run_path is None: + raise ValueError( + "SIRJI_RUN_PATH is not set as an environment variable") + return run_path + @retry_on_exception() def create_thread(self, problem_statement): diff --git a/agents/sirji_agents/researcher/researcher.py b/agents/sirji_agents/researcher/researcher.py index b2acb87..35aa33d 100644 --- a/agents/sirji_agents/researcher/researcher.py +++ b/agents/sirji_agents/researcher/researcher.py @@ -138,6 +138,8 @@ def message(self, input_message): return self._sync_codebase(parsed_message), 0, 0 elif action == ActionEnum.INFER.name: return self._handle_infer(parsed_message) + elif action == ActionEnum.INFER_IN_EXISTING_THREAD.name: + return self._handle_infer_in_existing_thread(parsed_message) # if action == ActionEnum.TRAIN_USING_SEARCH_TERM.name: # return self._handle_train_using_search_term(parsed_message) @@ -208,6 +210,28 @@ def replace_sirji_tags(self, text): text = text.replace(sirji_tag, file_content) return text + + def _handle_infer_in_existing_thread(self, parsed_message): + """Private method to handle inference requests in an existing thread.""" + self.logger.info(f"Infering in existing thread: {parsed_message.get('BODY')}") + + complete_session_id = self.init_payload.get("complete_session_id") + + + if complete_session_id is None: + return "Error: No active thread found", 0, 0 + + thread_id = self.init_payload.get("thread_ids_map").get(complete_session_id)[-1] + + if thread_id is None: + return "Error: There is no existing thread found.", 0, 0 + + + self.init_payload['thread_id'] = thread_id + + response, prompt_tokens, completion_tokens = self._infer(parsed_message.get('BODY')) + + return self._generate_message(parsed_message.get('TO'), parsed_message.get('FROM'), response), prompt_tokens, completion_tokens def _handle_infer(self, parsed_message): """Private method to handle inference requests.""" @@ -494,10 +518,12 @@ def create_assistant(self, body): assistant_details = { "assistant_id": assistant.id, "vector_store_id": vector_store.id, + "thread_ids_map": {}, "status": "active" } print(assistant_details) + self.logger.info("Assistant details: %s", assistant_details) assistant_details_path = os.path.join(self._get_run_path(), "assistant_details.json") with open(assistant_details_path, 'w') as f: json.dump(assistant_details, f, indent=4) diff --git a/messages/sirji_messages/action_enum.py b/messages/sirji_messages/action_enum.py index 5097e1b..f802721 100644 --- a/messages/sirji_messages/action_enum.py +++ b/messages/sirji_messages/action_enum.py @@ -38,3 +38,4 @@ class ActionEnum(Enum): STORE_IN_AGENT_OUTPUT = auto() LOG_STEPS = auto() SYNC_CODEBASE = auto() + INFER_IN_EXISTING_THREAD = auto() diff --git a/messages/sirji_messages/messages/actions/infer_in_existing_thread.py b/messages/sirji_messages/messages/actions/infer_in_existing_thread.py new file mode 100644 index 0000000..ea89901 --- /dev/null +++ b/messages/sirji_messages/messages/actions/infer_in_existing_thread.py @@ -0,0 +1,28 @@ +import textwrap + +from sirji_messages import AgentEnum, ActionEnum + +from .base import BaseMessages + +class InferInExistingThread(BaseMessages): + + def __init__(self): + self.action = ActionEnum.INFER_IN_EXISTING_THREAD.name + self.to_agent = AgentEnum.RESEARCHER.name + + super().__init__() + + + def sample(self): + return self.generate({ + "from_agent_id": "{{Your Agent ID}}", + "step": "Provide the step number here for the ongoing step if any.", + "summary": "{{Display a concise summary to the user, describing the action using the present continuous tense.}}", + "body": textwrap.dedent(""" + {{query}}""")}) + + def description(self): + return "Infer using existing thread" + + def instructions(self): + return [] diff --git a/messages/sirji_messages/messages/factory.py b/messages/sirji_messages/messages/factory.py index 8dd5cd3..5da9b2e 100644 --- a/messages/sirji_messages/messages/factory.py +++ b/messages/sirji_messages/messages/factory.py @@ -28,6 +28,7 @@ from .actions.log_steps import LogSteps from .actions.sync_codebase import SyncCodebase from .actions.create_assistant import CreateAssistantMessage +from .actions.infer_in_existing_thread import InferInExistingThread class MetaMessageFactory(type): @@ -72,6 +73,7 @@ class MessageFactory(metaclass=MetaMessageFactory): ActionEnum.DO_NOTHING: DoNothing, ActionEnum.LOG_STEPS: LogSteps, ActionEnum.SYNC_CODEBASE: SyncCodebase, - ActionEnum.CREATE_ASSISTANT: CreateAssistantMessage + ActionEnum.CREATE_ASSISTANT: CreateAssistantMessage, + ActionEnum.INFER_IN_EXISTING_THREAD: InferInExistingThread } diff --git a/messages/sirji_messages/permissions.py b/messages/sirji_messages/permissions.py index 5cc1db5..42c1dc8 100644 --- a/messages/sirji_messages/permissions.py +++ b/messages/sirji_messages/permissions.py @@ -18,10 +18,7 @@ ActionEnum.EXTRACT_DEPENDENCIES, ActionEnum.STORE_IN_SCRATCH_PAD, ActionEnum.DO_NOTHING, - ActionEnum.LOG_STEPS, - ActionEnum.INFER, - ActionEnum.CREATE_ASSISTANT, - ActionEnum.SYNC_CODEBASE + ActionEnum.LOG_STEPS }, (AgentEnum.ANY, AgentEnum.SIRJI_USER): { ActionEnum.QUESTION @@ -35,6 +32,12 @@ (AgentEnum.ORCHESTRATOR, AgentEnum.ANY): { ActionEnum.INVOKE_AGENT, ActionEnum.INVOKE_AGENT_EXISTING_SESSION + }, + (AgentEnum.ANY, AgentEnum.RESEARCHER): { + ActionEnum.INFER, + ActionEnum.CREATE_ASSISTANT, + ActionEnum.SYNC_CODEBASE, + ActionEnum.INFER_IN_EXISTING_THREAD } } diff --git a/sirji/vscode-extension/src/py_scripts/agents/research_agent.py b/sirji/vscode-extension/src/py_scripts/agents/research_agent.py index 405adc3..c6a4f49 100644 --- a/sirji/vscode-extension/src/py_scripts/agents/research_agent.py +++ b/sirji/vscode-extension/src/py_scripts/agents/research_agent.py @@ -72,7 +72,7 @@ def read_file(self, input_file_path): contents = file.read() return contents - def main(self, agent_id): + def main(self, agent_id, agent_session_id, agent_callstack): sirji_installation_dir = os.environ.get("SIRJI_INSTALLATION_DIR") sirji_run_path = os.environ.get("SIRJI_RUN_PATH") @@ -113,7 +113,8 @@ def main(self, agent_id): assistant_details = self.read_assistant_details() init_payload = assistant_details - + init_payload['complete_session_id'] = agent_callstack + '.' + agent_session_id + response, prompt_tokens_consumed, completion_tokens_consumed = self.process_message(message_str, conversations, init_payload) input_tokens += prompt_tokens_consumed @@ -125,5 +126,12 @@ def main(self, agent_id): self.write_conversations_to_file(conversation_file_path, conversations, input_tokens, output_tokens, max_input_tokens_for_a_prompt, max_output_tokens_for_a_prompt, 'gpt-4o') if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process interactions.") + parser.add_argument("--agent_session_id", required=True, help="Agent Session Id") + parser.add_argument("--agent_callstack", required=True, help="Agent Call Stack") + + + args = parser.parse_args() agent_runner = ResearchAgentRunner() - agent_runner.main('RESEARCHER') \ No newline at end of file + agent_runner.main('RESEARCHER',args.agent_session_id, args.agent_callstack) + \ No newline at end of file diff --git a/sirji/vscode-extension/src/py_scripts/cleanup.py b/sirji/vscode-extension/src/py_scripts/cleanup.py index 6e46038..ebb4b27 100644 --- a/sirji/vscode-extension/src/py_scripts/cleanup.py +++ b/sirji/vscode-extension/src/py_scripts/cleanup.py @@ -30,6 +30,10 @@ def cleanup_assistant(self, assistant_id): cleanup_instance = CleanupFactory.get_instance() cleanup_instance.delete_assistant(assistant_id) + def cleanup_thread(self, thread_id): + cleanup_instance = CleanupFactory.get_instance() + cleanup_instance.delete_thread(thread_id) + def cleanup_file(self, file_path): cleanup_instance = CleanupFactory.get_instance() cleanup_instance.delete_file(file_path) @@ -75,6 +79,12 @@ def main(self): if assistant_details and assistant_details.get('status') == 'active': if assistant_details.get('vector_store_id'): self.cleanup_vector_store(assistant_details.get('vector_store_id')) + if assistant_details.get('thread_ids_map'): + thread_ids_array = assistant_details.get('thread_ids_map').values() + flattended_thread_ids_array = sum(thread_ids_array, []) + for thread_id in flattended_thread_ids_array: + self.cleanup_thread(thread_id) + if assistant_details.get('assistant_id'): self.cleanup_assistant(assistant_details.get('assistant_id')) # Update the status in assistant_details.json to 'deleted' diff --git a/sirji/vscode-extension/src/utils/facilitator.ts b/sirji/vscode-extension/src/utils/facilitator.ts index 781f20e..dbb915b 100644 --- a/sirji/vscode-extension/src/utils/facilitator.ts +++ b/sirji/vscode-extension/src/utils/facilitator.ts @@ -735,12 +735,15 @@ export class Facilitator { case ACTOR_ENUM.RESEARCHER: console.log('Researcher message', parsedMessage); try { + let agentCallstack = oThis.stackManager.getStack(); + let sessionId = oThis.sessionManager?.reuseSession(agentCallstack); await spawnAdapter( oThis.context, oThis.sirjiInstallationFolderPath, oThis.sirjiRunFolderPath, oThis.projectRootPath, - path.join(__dirname, '..', 'py_scripts', 'agents', 'research_agent.py') + path.join(__dirname, '..', 'py_scripts', 'agents', 'research_agent.py'), + ['--agent_callstack', agentCallstack, '--agent_session_id', sessionId ?? ''] ); } catch (error) { oThis.sendErrorToChatPanel(error); @@ -783,7 +786,7 @@ export class Facilitator { } catch (error) { oThis.sendErrorToChatPanel(error); keepFacilitating = false; - break; + break; } const agentConversationFilePath = path.join(oThis.sirjiRunFolderPath, 'conversations', `${agentCallstack}.${sessionId}.json`);