Skip to content

Commit

Permalink
chore: Minor code cleanup. The unpredictability of the LLM response i…
Browse files Browse the repository at this point in the history
…s a show-stopper.
  • Loading branch information
anirbanbasu committed Aug 19, 2024
1 parent cfe591d commit 1f3526a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 38 deletions.
41 changes: 22 additions & 19 deletions src/coder_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from langchain_core.runnables import RunnableConfig


import json
from code_executor import CodeExecutor
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph, START
Expand Down Expand Up @@ -172,7 +173,6 @@ def solve(self, state: AgentState) -> dict:
}
# Have we been presented with examples?
has_examples = bool(state.get(constants.AGENT_STATE__KEY_EXAMPLES))
ic(f"State in solve: {state}")
# If `draft`` is requested in the state then output a candidate solution
output_key = (
constants.AGENT_STATE__KEY_CANDIDATE
Expand Down Expand Up @@ -234,33 +234,36 @@ def evaluate(self, state: AgentState) -> dict:
Returns:
dict: The updated state of the agent.
"""
ic(f"State in evaluate: {state}")
test_cases = state[constants.AGENT_STATE__KEY_TEST_CASES]
# Extract the `AIMessage` that is expected to contain the code from the last call to the solver that was NOT to generate a candidate solution.
ai_message: AIMessage = state[constants.AGENT_STATE__KEY_MESSAGES][-1]
# ai_message is a list of dictionaries.
# tool_call_args = ai_message.content[0]
if (
ai_message.tool_calls[-1][constants.AGENT_TOOL_CALL__NAME]
!= CoderOutput.__name__
):
return {
constants.AGENT_STATE__KEY_MESSAGES: [
self.format_as_tool_message(
response=f"Invalid tool call `{ai_message.tool_calls[-1][constants.AGENT_TOOL_CALL__NAME]}`. You should call the tool `{CoderOutput.__name__}` to format your response.",
ai_message=ai_message,
)
],
constants.AGENT_STATE__KEY_STATUS: constants.AGENT_NODE__EVALUATE_STATUS_ERROR,
}
tool_call_args = ai_message.tool_calls[-1][constants.AGENT_TOOL_CALL__ARGS]
# Extract the code from the tool call.
code: str = tool_call_args[constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE]
solution: dict = None
if ai_message.tool_calls:
solution = ai_message.tool_calls[-1][constants.AGENT_TOOL_CALL__ARGS]
if (
ai_message.tool_calls[-1][constants.AGENT_TOOL_CALL__NAME]
!= CoderOutput.__name__
):
return {
constants.AGENT_STATE__KEY_MESSAGES: [
self.format_as_tool_message(
response=f"Invalid tool call `{ai_message.tool_calls[-1][constants.AGENT_TOOL_CALL__NAME]}`. You should call the tool `{CoderOutput.__name__}` to format your response.",
ai_message=ai_message,
)
],
constants.AGENT_STATE__KEY_STATUS: constants.AGENT_NODE__EVALUATE_STATUS_ERROR,
}
else:
solution = json.loads(ai_message.content)
# Extract the code from the solution.
code: str = solution[constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE]
if not code:
return {
constants.AGENT_STATE__KEY_MESSAGES: [
self.format_as_tool_message(
response="No code was generated. Please try again using the correct Python code.",
response="No code was generated. Please generate correct Python code.",
ai_message=ai_message,
)
],
Expand Down
6 changes: 4 additions & 2 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,12 @@
Finally, a well-documented and working Python 3 `code` for your solution. Do not use external libraries. Your code must be able to accept inputs from `sys.stdin` and write the final output to `sys.stdout` (or, to `sys.stderr` in case of errors).
Please format your response as a JSON dictionary, using `reasoning`, `pseudocode`, and `code` as keys.
Optional examples of similar problems and solutions (may not be in Python):
You may be provided below with ptional examples of similar problems and solutions (may not be in Python).
[BEGIN EXAMPLES]
{examples}
[END EXAMPLES]
Given problem and your conversation with the user about it:
The problem and your conversation with the user about it are given below.
"""


Expand Down
35 changes: 18 additions & 17 deletions src/webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from coder_agent import CoderOutput, MultiAgentDirectedGraph, TestCase
from utils import parse_env
import json

try:
from icecream import ic
Expand Down Expand Up @@ -244,23 +245,23 @@ def find_solution(
constants.AGENT_STATE__KEY_MESSAGES
][-1]
# FIXME: Why are there more than one tool calls to the same tool?
if (
response.tool_calls[-1][constants.AGENT_TOOL_CALL__NAME]
== CoderOutput.__name__
):
tool_call_args = response.tool_calls[-1][
constants.AGENT_TOOL_CALL__ARGS
]
# coder_output: CoderOutput = CoderOutput.parse_json(tool_call_args)
yield [
tool_call_args[
constants.PYDANTIC_MODEL__CODE_OUTPUT__REASONING
],
tool_call_args[
constants.PYDANTIC_MODEL__CODE_OUTPUT__PSEUDOCODE
],
tool_call_args[constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE],
]
solution: dict = None
if response.tool_calls:
if (
response.tool_calls[-1][constants.AGENT_TOOL_CALL__NAME]
== CoderOutput.__name__
):
solution = response.tool_calls[-1][
constants.AGENT_TOOL_CALL__ARGS
]
else:
solution = json.loads(response.content)
ic(solution)
yield [
solution[constants.PYDANTIC_MODEL__CODE_OUTPUT__REASONING],
solution[constants.PYDANTIC_MODEL__CODE_OUTPUT__PSEUDOCODE],
solution[constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE],
]

def add_test_case(
self, test_cases: list[TestCase] | None, test_case_in: str, test_case_out: str
Expand Down

0 comments on commit 1f3526a

Please sign in to comment.