-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path langgraph.py
241 lines (199 loc) · 6.97 KB
/
langgraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import os
import sqlite3
import time
import uuid
from operator import itemgetter
from typing import Annotated, Dict, List, TypedDict

from IPython.display import Image, display
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import AnyMessage, add_messages
# Set environment variables
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
# Set up the LLM.
# The Groq API key is read from the GROQ_API_KEY environment variable so real
# credentials never need to be committed to source. The original placeholder
# string remains the fallback, so behaviour is unchanged when it is unset.
llm = ChatGroq(
    temperature=0,
    groq_api_key=os.environ.get("GROQ_API_KEY", "groq_api"),
    model_name="llama3-8b-8192",
)
# Define the prompt template.
# System instructions for every generation request: the task description,
# a blank line, then the lead-in to the user's question.
_CODE_GEN_SYSTEM_PROMPT = (
    "You are a coding assistant. Ensure any code you provide can be executed "
    "with all required imports and variables defined. Structure your answer: "
    "1) a prefix describing the code solution, 2) the imports, 3) the "
    "functioning code block.\n\n Here is the user question:"
)
# The running conversation is spliced in after the system message through the
# "messages" placeholder.
code_gen_prompt_claude = ChatPromptTemplate.from_messages(
    [
        ("system", _CODE_GEN_SYSTEM_PROMPT),
        ("placeholder", "{messages}"),
    ]
)
# Define the data model
class code(BaseModel):
    """Schema for code solutions to the user's coding question."""

    # The docstring above doubles as the tool description sent to the LLM, so
    # it is kept short. There is deliberately no un-annotated
    # `description = "..."` attribute: pydantic v1 infers such attributes as
    # optional fields, which leaked a spurious `description` argument into the
    # structured-output schema.
    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")
# Set up the structured output.
# The system prompt is now piped in front of the model; before, it was defined
# but never used. The leading passthrough map lets callers keep invoking the
# chain with a plain list of messages: the list is wrapped as
# {"messages": <list>} to fill the prompt's "messages" placeholder.
code_gen_chain = (
    {"messages": RunnablePassthrough()}
    | code_gen_prompt_claude
    | llm.with_structured_output(code, include_raw=False)
)
# Define the graph state
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : "yes" if the most recent code check failed, "no" if it
            passed. Written by the check node; a string, not a bool.
        messages : Conversation history: user question, assistant attempts,
            error feedback. Node updates are combined through the
            `add_messages` reducer rather than overwriting the list.
        generation : Latest code solution. At runtime this holds the
            structured `code` object produced by the generation chain,
            despite the `str` annotation.
        iterations : Number of generation attempts made so far.
    """
    error: str  # "yes" / "no"; drives decide_to_finish
    messages: Annotated[list[AnyMessage], add_messages]  # merged via add_messages
    generation: str  # NOTE(review): actually a `code` instance at runtime
    iterations: int  # incremented once per generate step
# Define the nodes
def generate(state: GraphState):
    """
    Generate a code solution.

    Args:
        state (dict): The current graph state; reads "messages" and
            "iterations".

    Returns:
        dict: State update with the new "generation", a one-element
        "messages" list holding the assistant's description of the attempt
        (merged into the history by the add_messages reducer), and the
        incremented "iterations" count.
    """
    print("---GENERATING CODE SOLUTION---")
    # State. "error" is intentionally not read: this node never used it, and
    # it is absent on the very first step when the caller does not seed it,
    # so the old state["error"] lookup could fail before any work was done.
    messages = state["messages"]
    iterations = state["iterations"]
    # Solution
    code_solution = code_gen_chain.invoke(messages)
    # Only the *new* message is returned. Appending to state["messages"] in
    # place (the old `messages += ...`) mutated the graph's own list and then
    # handed that same list back to the add_messages reducer, which could
    # record the assistant message twice.
    new_message = (
        "assistant",
        f"Here is my attempt to solve the problem: {code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
    )
    # Add delay to reduce API requests
    time.sleep(1)  # Wait for 1 second
    return {
        "generation": code_solution,
        "messages": [new_message],
        "iterations": iterations + 1,
    }
# Shared tail of the retry request sent back to the model after a failed check.
_RETRY_INSTRUCTIONS = (
    "Reflect on this error and your prior attempt to solve the problem. "
    "(1) State what you think went wrong with the prior solution and "
    "(2) try to solve this problem again. Return the FULL SOLUTION. "
    "Use the code tool to structure the output with a prefix, imports, and code block:"
)


def _code_check_failure(code_solution, iterations, reason):
    """Build the state update for a failed check, asking the model to retry."""
    return {
        "generation": code_solution,
        "messages": [("user", f"{reason} {_RETRY_INSTRUCTIONS}")],
        "iterations": iterations,
        "error": "yes",
    }


def code_check(state: "GraphState"):
    """
    Check the generated code by executing it.

    The import block is executed on its own first, then imports and code
    together; each run uses a fresh, empty namespace.

    Args:
        state (dict): The current graph state; reads "generation" (an object
            with `imports` and `code` attributes) and "iterations".

    Returns:
        dict: State update with "error" set to "yes" or "no". On failure it
        also carries one new user message describing the error and asking
        the model to retry (merged into "messages" by the reducer); the
        state's own message list is never mutated.
    """
    print("---CHECKING CODE---")
    code_solution = state["generation"]
    iterations = state["iterations"]
    imports = code_solution.imports
    body = code_solution.code

    # SECURITY: `imports` and `body` are LLM output executed in-process with
    # this program's privileges. Only run this on a trusted machine, or move
    # execution into a sandbox.

    # Check imports (isolated namespace so generated code can't touch this module)
    try:
        exec(imports, {})
    except Exception as e:
        print("---CODE IMPORT CHECK: FAILED---")
        return _code_check_failure(
            code_solution,
            iterations,
            f"Your solution failed the import test. Here is the error: {e}.",
        )

    # Check execution
    try:
        combined_code = f"{imports}\n{body}"
        exec(combined_code, {})
    except Exception as e:
        print("---CODE BLOCK CHECK: FAILED---")
        return _code_check_failure(
            code_solution,
            iterations,
            f"Your solution failed the code execution test: {e}.",
        )

    # No errors
    print("---NO CODE TEST FAILURES---")
    return {
        "generation": code_solution,
        "iterations": iterations,
        "error": "no",
    }
def decide_to_finish(state: GraphState):
    """
    Decide whether the graph should stop or try another generation.

    Args:
        state (dict): The current graph state; reads "error" and "iterations".

    Returns:
        str: "end" when the last check passed or the retry budget is used up,
        otherwise "generate". The budget is the module-level `max_iterations`,
        which is looked up when this function runs, not when it is defined.
    """
    error = state["error"]
    iterations = state["iterations"]
    # `>=` rather than `==`: the loop must never run past the budget, even if
    # the counter starts at or above max_iterations.
    if error == "no" or iterations >= max_iterations:
        print("---DECISION: FINISH---")
        return "end"
    print("---DECISION: RE-TRY SOLUTION---")
    return "generate"
# Define the graph
builder = StateGraph(GraphState)
# Add nodes
builder.add_node("generate", generate)  # generation solution
builder.add_node("check_code", code_check)  # check code
# Build graph: generate -> check_code -> (back to generate | END)
builder.set_entry_point("generate")
builder.add_edge("generate", "check_code")
builder.add_conditional_edges(
    "check_code",
    decide_to_finish,
    {
        "end": END,
        "generate": "generate",
    },
)
# Compile the graph with an in-memory SQLite checkpointer.
# The connection is opened explicitly instead of via
# SqliteSaver.from_conn_string(): in newer langgraph-checkpoint-sqlite
# releases that helper is a context manager rather than returning a saver,
# so passing its result to compile() fails. Constructing the saver from a
# connection works with both old and new releases.
memory = SqliteSaver(sqlite3.connect(":memory:", check_same_thread=False))
graph = builder.compile(checkpointer=memory)
# Display the graph (best effort: rendering depends on optional tooling)
try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception as exc:
    # Rendering is optional, so don't abort the run. Report why it was
    # skipped instead of hiding it, and catch Exception rather than using a
    # bare except so Ctrl-C / SystemExit still propagate.
    print(f"(graph diagram not displayed: {exc})")
# Run the graph
_printed = set()
thread_id = str(uuid.uuid4())
config = {
    "configurable": {
        # Checkpoints are accessed by thread_id
        "thread_id": thread_id,
    }
}
# Ask user for input
question = input("Enter your question or search query: ")
# Maximum number of generate/check rounds; read by decide_to_finish.
max_iterations = 5
# "error" is seeded alongside messages/iterations: keys that were never set
# are missing from the state the first node receives, so an unseeded
# state["error"] lookup fails.
events = graph.stream(
    {"messages": [("user", question)], "iterations": 0, "error": ""},
    config,
    stream_mode="values",
)
def _print_event(event, _printed):
if str(event) not in _printed:
print(event)
_printed.add(str(event))
# Drain the stream, echoing each distinct state snapshot; once the loop ends,
# `latest` refers to the last snapshot emitted by the graph.
for latest in events:
    _print_event(latest, _printed)
# Output the final result
print("Final Result:")
print(latest["generation"])