Skip to content

Commit

Permalink
added a new action to infer in existing session
Browse files Browse the repository at this point in the history
  • Loading branch information
dakshbhardwaj committed Aug 1, 2024
1 parent 6c196a8 commit 66b4ed0
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 22 deletions.
3 changes: 3 additions & 0 deletions agents/sirji_agents/llm/generic/system_prompts/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def system_prompt(self):
action_list.add(ActionEnum[action])
allowed_response_templates_str += '\n' + allowed_response_templates(AgentEnum.ANY, AgentEnum.EXECUTOR, action_list) + '\n'

action_list_researcher = permissions_dict[(AgentEnum.ANY, AgentEnum.RESEARCHER)]
allowed_response_templates_str += '\n' + allowed_response_templates(AgentEnum.ANY, AgentEnum.RESEARCHER, action_list_researcher) + '\n'

allowed_response_templates_str += textwrap.dedent(f"""For updating in project folder use either {ActionEnum.FIND_AND_REPLACE.name}, {ActionEnum.INSERT_ABOVE.name} or {ActionEnum.INSERT_BELOW.name} actions. Ensure you provide the exact matching string in find from file, with the exact number of lines and proper indentation for insert and replace actions.""") + '\n'
allowed_response_templates_str += '\n' + allowed_response_templates(AgentEnum.ANY, AgentEnum.CALLER, permissions_dict[(AgentEnum.ANY, AgentEnum.CALLER)]) + '\n'

Expand Down
10 changes: 10 additions & 0 deletions agents/sirji_agents/researcher/cleanup/openai_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ def delete_file(self, file_path):
response = self.client.files.delete(file_path)
print(response)
self.logger.info(response)
except Exception as e:
print(e)
self.logger.error(e)

@retry_on_exception()
def delete_thread(self, thread_id):
try:
response = self.client.beta.threads.delete(thread_id = thread_id)
print(response)
self.logger.info(response)
except Exception as e:
print(e)
self.logger.error(e)
50 changes: 38 additions & 12 deletions agents/sirji_agents/researcher/inferer/openai_assistant.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import time
from openai import OpenAI
Expand All @@ -18,6 +19,7 @@ def __init__(self, init_payload):
"""

self.init_payload = init_payload
self.thread_id = self.init_payload.get('thread_id', None)

# Fetch OpenAI API key from an environment variable
api_key = os.environ.get("SIRJI_MODEL_PROVIDER_API_KEY")
Expand All @@ -37,6 +39,7 @@ def __init__(self, init_payload):
self.assistant_id = self.init_payload['assistant_id']

self.logger.info("Completed initializing OpenAI Assistant Inferer")
self.assistant_details_path = os.path.join(self._get_run_path(), "assistant_details.json")

def infer(self, problem_statement):
self.logger.info("Started inferring using OpenAI Assistant Inferer")
Expand All @@ -56,26 +59,49 @@ def infer(self, problem_statement):
# content=problem_statement,
# )

self.thread_id = None

try:
self.thread_id = self.create_thread(problem_statement)
if not self.thread_id:
self.thread_id = self.create_thread(problem_statement)
if self.thread_id:
try:
self.logger.info("File path: %s", self.assistant_details_path)
with open(self.assistant_details_path, 'r') as file:
assistant_details = json.load(file)
self.logger.info("Loaded assistant_details: %s", assistant_details)

with open(self.assistant_details_path, 'w') as file:

thread_ids_map = assistant_details.get('thread_ids_map')
complete_session_id = self.init_payload.get('complete_session_id')

thread_ids = thread_ids_map.get(complete_session_id, None)
if thread_ids is None:
thread_ids_map[complete_session_id] = []

thread_ids_map[complete_session_id].append(self.thread_id)

self.logger.info("Saving thread_id: %s", self.thread_id)
self.logger.info("Saving assistant_details: %s", assistant_details)
json.dump(assistant_details, file, indent=4)
except Exception as save_e:
self.logger.error("Failed to save thread_id: %s", str(save_e))
print('self.thread_id', self.thread_id)

response = self._fetch_response()
except Exception as e:
self.logger.error("An error occurred during inference: %s", str(e))
response = 'An error occurred during inference', 0, 0
finally:
if self.thread_id:
try:
deleteResponse = self.client.beta.threads.delete(thread_id=self.thread_id)
self.logger.info("Thread deleted successfully: %s", deleteResponse)
print('deleteResponse', deleteResponse)
except Exception as delete_e:
self.logger.error("Failed to delete thread: %s", str(delete_e))
print('self.thread_id', self.thread_id)


return response

def _get_run_path(self):
run_path = os.environ.get("SIRJI_RUN_PATH")
if run_path is None:
raise ValueError(
"SIRJI_RUN_PATH is not set as an environment variable")
return run_path


@retry_on_exception()
def create_thread(self, problem_statement):
Expand Down
26 changes: 26 additions & 0 deletions agents/sirji_agents/researcher/researcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def message(self, input_message):
return self._sync_codebase(parsed_message), 0, 0
elif action == ActionEnum.INFER.name:
return self._handle_infer(parsed_message)
elif action == ActionEnum.INFER_IN_EXISTING_THREAD.name:
return self._handle_infer_in_existing_thread(parsed_message)

# if action == ActionEnum.TRAIN_USING_SEARCH_TERM.name:
# return self._handle_train_using_search_term(parsed_message)
Expand Down Expand Up @@ -208,6 +210,28 @@ def replace_sirji_tags(self, text):
text = text.replace(sirji_tag, file_content)

return text

def _handle_infer_in_existing_thread(self, parsed_message):
"""Private method to handle inference requests in an existing thread."""
self.logger.info(f"Infering in existing thread: {parsed_message.get('BODY')}")

complete_session_id = self.init_payload.get("complete_session_id")


if complete_session_id is None:
return "Error: No active thread found", 0, 0

thread_id = self.init_payload.get("thread_ids_map").get(complete_session_id)[-1]

if thread_id is None:
return "Error: There is no existing thread found.", 0, 0


self.init_payload['thread_id'] = thread_id

response, prompt_tokens, completion_tokens = self._infer(parsed_message.get('BODY'))

return self._generate_message(parsed_message.get('TO'), parsed_message.get('FROM'), response), prompt_tokens, completion_tokens

def _handle_infer(self, parsed_message):
"""Private method to handle inference requests."""
Expand Down Expand Up @@ -494,10 +518,12 @@ def create_assistant(self, body):
assistant_details = {
"assistant_id": assistant.id,
"vector_store_id": vector_store.id,
"thread_ids_map": {},
"status": "active"
}

print(assistant_details)
self.logger.info("Assistant details: %s", assistant_details)
assistant_details_path = os.path.join(self._get_run_path(), "assistant_details.json")
with open(assistant_details_path, 'w') as f:
json.dump(assistant_details, f, indent=4)
Expand Down
1 change: 1 addition & 0 deletions messages/sirji_messages/action_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ class ActionEnum(Enum):
STORE_IN_AGENT_OUTPUT = auto()
LOG_STEPS = auto()
SYNC_CODEBASE = auto()
INFER_IN_EXISTING_THREAD = auto()
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import textwrap

from sirji_messages import AgentEnum, ActionEnum

from .base import BaseMessages

class InferInExistingThread(BaseMessages):

def __init__(self):
self.action = ActionEnum.INFER_IN_EXISTING_THREAD.name
self.to_agent = AgentEnum.RESEARCHER.name

super().__init__()


def sample(self):
return self.generate({
"from_agent_id": "{{Your Agent ID}}",
"step": "Provide the step number here for the ongoing step if any.",
"summary": "{{Display a concise summary to the user, describing the action using the present continuous tense.}}",
"body": textwrap.dedent("""
{{query}}""")})

def description(self):
return "Infer using existing thread"

def instructions(self):
return []
4 changes: 3 additions & 1 deletion messages/sirji_messages/messages/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .actions.log_steps import LogSteps
from .actions.sync_codebase import SyncCodebase
from .actions.create_assistant import CreateAssistantMessage
from .actions.infer_in_existing_thread import InferInExistingThread


class MetaMessageFactory(type):
Expand Down Expand Up @@ -72,6 +73,7 @@ class MessageFactory(metaclass=MetaMessageFactory):
ActionEnum.DO_NOTHING: DoNothing,
ActionEnum.LOG_STEPS: LogSteps,
ActionEnum.SYNC_CODEBASE: SyncCodebase,
ActionEnum.CREATE_ASSISTANT: CreateAssistantMessage
ActionEnum.CREATE_ASSISTANT: CreateAssistantMessage,
ActionEnum.INFER_IN_EXISTING_THREAD: InferInExistingThread
}

11 changes: 7 additions & 4 deletions messages/sirji_messages/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
ActionEnum.EXTRACT_DEPENDENCIES,
ActionEnum.STORE_IN_SCRATCH_PAD,
ActionEnum.DO_NOTHING,
ActionEnum.LOG_STEPS,
ActionEnum.INFER,
ActionEnum.CREATE_ASSISTANT,
ActionEnum.SYNC_CODEBASE
ActionEnum.LOG_STEPS
},
(AgentEnum.ANY, AgentEnum.SIRJI_USER): {
ActionEnum.QUESTION
Expand All @@ -35,6 +32,12 @@
(AgentEnum.ORCHESTRATOR, AgentEnum.ANY): {
ActionEnum.INVOKE_AGENT,
ActionEnum.INVOKE_AGENT_EXISTING_SESSION
},
(AgentEnum.ANY, AgentEnum.RESEARCHER): {
ActionEnum.INFER,
ActionEnum.CREATE_ASSISTANT,
ActionEnum.SYNC_CODEBASE,
ActionEnum.INFER_IN_EXISTING_THREAD
}
}

Expand Down
14 changes: 11 additions & 3 deletions sirji/vscode-extension/src/py_scripts/agents/research_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def read_file(self, input_file_path):
contents = file.read()
return contents

def main(self, agent_id):
def main(self, agent_id, agent_session_id, agent_callstack):
sirji_installation_dir = os.environ.get("SIRJI_INSTALLATION_DIR")
sirji_run_path = os.environ.get("SIRJI_RUN_PATH")

Expand Down Expand Up @@ -113,7 +113,8 @@ def main(self, agent_id):
assistant_details = self.read_assistant_details()

init_payload = assistant_details

init_payload['complete_session_id'] = agent_callstack + '.' + agent_session_id

response, prompt_tokens_consumed, completion_tokens_consumed = self.process_message(message_str, conversations, init_payload)

input_tokens += prompt_tokens_consumed
Expand All @@ -125,5 +126,12 @@ def main(self, agent_id):
self.write_conversations_to_file(conversation_file_path, conversations, input_tokens, output_tokens, max_input_tokens_for_a_prompt, max_output_tokens_for_a_prompt, 'gpt-4o')

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process interactions.")
parser.add_argument("--agent_session_id", required=True, help="Agent Session Id")
parser.add_argument("--agent_callstack", required=True, help="Agent Call Stack")


args = parser.parse_args()
agent_runner = ResearchAgentRunner()
agent_runner.main('RESEARCHER')
agent_runner.main('RESEARCHER',args.agent_session_id, args.agent_callstack)

10 changes: 10 additions & 0 deletions sirji/vscode-extension/src/py_scripts/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def cleanup_assistant(self, assistant_id):
cleanup_instance = CleanupFactory.get_instance()
cleanup_instance.delete_assistant(assistant_id)

def cleanup_thread(self, thread_id):
cleanup_instance = CleanupFactory.get_instance()
cleanup_instance.delete_thread(thread_id)

def cleanup_file(self, file_path):
cleanup_instance = CleanupFactory.get_instance()
cleanup_instance.delete_file(file_path)
Expand Down Expand Up @@ -75,6 +79,12 @@ def main(self):
if assistant_details and assistant_details.get('status') == 'active':
if assistant_details.get('vector_store_id'):
self.cleanup_vector_store(assistant_details.get('vector_store_id'))
if assistant_details.get('thread_ids_map'):
thread_ids_array = assistant_details.get('thread_ids_map').values()
flattended_thread_ids_array = sum(thread_ids_array, [])
for thread_id in flattended_thread_ids_array:
self.cleanup_thread(thread_id)

if assistant_details.get('assistant_id'):
self.cleanup_assistant(assistant_details.get('assistant_id'))
# Update the status in assistant_details.json to 'deleted'
Expand Down
7 changes: 5 additions & 2 deletions sirji/vscode-extension/src/utils/facilitator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -735,12 +735,15 @@ export class Facilitator {
case ACTOR_ENUM.RESEARCHER:
console.log('Researcher message', parsedMessage);
try {
let agentCallstack = oThis.stackManager.getStack();
let sessionId = oThis.sessionManager?.reuseSession(agentCallstack);
await spawnAdapter(
oThis.context,
oThis.sirjiInstallationFolderPath,
oThis.sirjiRunFolderPath,
oThis.projectRootPath,
path.join(__dirname, '..', 'py_scripts', 'agents', 'research_agent.py')
path.join(__dirname, '..', 'py_scripts', 'agents', 'research_agent.py'),
['--agent_callstack', agentCallstack, '--agent_session_id', sessionId ?? '']
);
} catch (error) {
oThis.sendErrorToChatPanel(error);
Expand Down Expand Up @@ -783,7 +786,7 @@ export class Facilitator {
} catch (error) {
oThis.sendErrorToChatPanel(error);
keepFacilitating = false;
break;
break;
}

const agentConversationFilePath = path.join(oThis.sirjiRunFolderPath, 'conversations', `${agentCallstack}.${sessionId}.json`);
Expand Down

0 comments on commit 66b4ed0

Please sign in to comment.