
Http error handling (#129)
* upgrade openai and langchain

Co-authored-by: Jacob Vanmeter <jacobvm04@gmail.com>

* extend supabase timeout

* Handle moderation errors + other streaming errors explicitly

* Streaming error handling

* ✨ caching and skeletons (#127)

* Revert "✨ caching and skeletons (#127)"

This reverts commit 15e649e.

---------

Co-authored-by: hyusap <paulayush@gmail.com>
Co-authored-by: Vineeth Voruganti <13438633+VVoruganti@users.noreply.github.com>
3 people committed Dec 12, 2023
1 parent a850d0e commit 3ae592e
Showing 4 changed files with 128 additions and 159 deletions.
43 changes: 31 additions & 12 deletions agent/chain.py
@@ -5,6 +5,9 @@
 )
 from langchain.prompts import load_prompt, ChatPromptTemplate
 from langchain.schema import AIMessage, HumanMessage, BaseMessage
+
+from openai import BadRequestError
+
 from dotenv import load_dotenv
 
 from collections.abc import AsyncIterator
@@ -54,12 +57,11 @@ def think(cls, cache: Conversation, input: str):
         ])
         chain = thought_prompt | cls.llm
 
-        cache.add_message("thought", HumanMessage(content=input))
+        def save_new_messages(ai_response):
+            cache.add_message("response", HumanMessage(content=input))
+            cache.add_message("response", AIMessage(content=ai_response))
 
-        return Streamable(
-            chain.astream({}, {"tags": ["thought"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}),
-            lambda thought: cache.add_message("thought", AIMessage(content=thought))
-        )
+        return Streamable(chain.astream({}, {"tags": ["thought"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}), save_new_messages)
 
     @classmethod
     @sentry_sdk.trace
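
Note the pattern introduced here: instead of persisting the user's message before the model call, both sides of the exchange are now written in a single save_new_messages closure that Streamable invokes only once the stream completes cleanly, so a failed or filtered stream never leaves a half-saved exchange behind (respond below gets the same treatment). A minimal sketch of the deferred-save idea, using a stand-in FakeCache and fake_chain since the real Conversation cache and LangChain chain aren't shown in full:

import asyncio

class FakeCache:
    """Stand-in for the Conversation cache; records (type, text) pairs."""
    def __init__(self):
        self.messages = []

    def add_message(self, msg_type, text):
        self.messages.append((msg_type, text))

async def fake_chain():
    """Stand-in for chain.astream(...); yields chunks, then finishes."""
    for chunk in ["Hi", " there"]:
        yield chunk

async def consume(iterator, on_complete):
    """Accumulate chunks and fire the save callback only on clean completion."""
    content = ""
    async for chunk in iterator:
        content += chunk
    on_complete(content)

async def main():
    cache = FakeCache()
    user_input = "hello"

    def save_new_messages(ai_response):
        # Both messages are persisted together, after the stream has finished.
        cache.add_message("response", user_input)
        cache.add_message("response", ai_response)

    await consume(fake_chain(), save_new_messages)
    print(cache.messages)  # [('response', 'hello'), ('response', 'Hi there')]

asyncio.run(main())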
@@ -72,13 +74,12 @@ def respond(cls, cache: Conversation, thought: str, input: str):
         ])
         chain = response_prompt | cls.llm
 
-        cache.add_message("response", HumanMessage(content=input))
+        def save_new_messages(ai_response):
+            cache.add_message("response", HumanMessage(content=input))
+            cache.add_message("response", AIMessage(content=ai_response))
 
-        return Streamable(
-            chain.astream({ "thought": thought }, {"tags": ["response"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}),
-            lambda response: cache.add_message("response", AIMessage(content=response))
-        )
+        return Streamable(chain.astream({ "thought": thought }, {"tags": ["response"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}), save_new_messages)
 
     @classmethod
     @sentry_sdk.trace
     async def think_user_prediction(cls, cache: Conversation):
@@ -114,27 +115,45 @@ async def chat(cls, cache: Conversation, inp: str ) -> tuple[str, str]:
         return thought, response
 
 
 class Streamable:
     "A async iterator wrapper for langchain streams that saves on completion via callback"
 
     def __init__(self, iterator: AsyncIterator[BaseMessage], callback):
         self.iterator = iterator
         self.callback = callback
         self.content = ""
+        self.stream_error = False
 
     def __aiter__(self):
         return self
 
     async def __anext__(self):
         try:
+            if self.stream_error:
+                raise StopAsyncIteration
+
             data = await self.iterator.__anext__()
             self.content += data.content
             return data.content
         except StopAsyncIteration as e:
             self.callback(self.content)
             raise StopAsyncIteration
+        except BadRequestError as e:
+            if e.code == "content_filter":
+                self.stream_error = True
+                self.message = "Sorry, your message was flagged as inappropriate. Please try again."
+
+                return self.message
+            else:
+                raise Exception(e)
         except Exception as e:
-            raise e
+            sentry_sdk.capture_exception(e)
+
+            self.stream_error = True
+            self.message = "Sorry, an error occurred while streaming the response. Please try again."
+
+            return self.message
 
     async def __call__(self):
         async for _ in self:
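
The effect of the new __anext__ branches: a BadRequestError with code "content_filter" (a moderation rejection) or any other streaming failure no longer escapes to the caller mid-stream. The wrapper reports unexpected errors to Sentry, yields a user-facing apology as the final chunk, and sets stream_error so the next pull ends iteration cleanly. A minimal sketch of that swallow-and-finish behavior, with a stand-in FlakyStream in place of a real OpenAI-backed stream:

import asyncio

class FlakyStream:
    """Stand-in async iterator that yields one chunk, then raises."""
    def __init__(self):
        self.sent = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.sent:
            raise RuntimeError("boom")
        self.sent = True
        return "partial output"

class SafeStream:
    """Mirrors the Streamable error pattern: swap the error for a final chunk."""
    def __init__(self, iterator):
        self.iterator = iterator
        self.stream_error = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            # After an error has been surfaced, the next pull stops iteration.
            if self.stream_error:
                raise StopAsyncIteration
            return await self.iterator.__anext__()
        except StopAsyncIteration:
            raise
        except Exception:
            self.stream_error = True
            return "Sorry, an error occurred while streaming the response."

async def main():
    async for chunk in SafeStream(FlakyStream()):
        print(chunk)  # "partial output", then the apology, then a clean stop

asyncio.run(main())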
6 changes: 5 additions & 1 deletion agent/mediator.py
@@ -5,13 +5,17 @@
 from dotenv import load_dotenv
 # Supabase for Postgres Management
 from supabase.client import create_client, Client
+from supabase.lib.client_options import ClientOptions
 from typing import List, Tuple, Dict
 load_dotenv()
 
 class SupabaseMediator:
     @sentry_sdk.trace
     def __init__(self):
-        self.supabase: Client = create_client(os.environ['SUPABASE_URL'], os.environ['SUPABASE_KEY'])
+        # Change the network db timeout to 60 seconds since the default is only 5 seconds
+        timeout_client_options = ClientOptions(postgrest_client_timeout=60)
+        self.supabase: Client = create_client(os.environ['SUPABASE_URL'], os.environ['SUPABASE_KEY'], timeout_client_options)
 
         self.memory_table = os.environ["MEMORY_TABLE"]
         self.conversation_table = os.environ["CONVERSATION_TABLE"]
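
Worth calling out on the timeout change: supabase-py's default PostgREST network timeout is 5 seconds, which longer queries can exceed. A minimal sketch of the same configuration in isolation, assuming SUPABASE_URL and SUPABASE_KEY are set in the environment and using a hypothetical "messages" table for illustration:

import os

from supabase.client import create_client, Client
from supabase.lib.client_options import ClientOptions

# Raise the PostgREST request timeout from the 5-second default to 60 seconds.
options = ClientOptions(postgrest_client_timeout=60)
supabase: Client = create_client(
    os.environ["SUPABASE_URL"], os.environ["SUPABASE_KEY"], options
)

# A slow query now has up to 60 seconds before the client gives up.
# ("messages" is a hypothetical table name for illustration.)
rows = supabase.table("messages").select("*").limit(1000).execute()
print(len(rows.data))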
