Skip to content

Commit

Permalink
mistral working w/ RoleEnum.system error
Browse files Browse the repository at this point in the history
  • Loading branch information
vintrocode committed Dec 22, 2023
1 parent 479f1ba commit cc7975e
Show file tree
Hide file tree
Showing 11 changed files with 497 additions and 621 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ USER app

COPY agent/ agent/
COPY common/ common/
COPY bot/ bot/
# COPY bot/ bot/
COPY api/ api/

# https://stackoverflow.com/questions/29663459/python-app-does-not-print-anything-when-running-detached-in-docker
Expand Down
2 changes: 1 addition & 1 deletion agent/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .mediator import SupabaseMediator
import uuid
from typing import List, Tuple, Dict
from langchain.schema import BaseMessage
from langchain_core.messages import BaseMessage
import sentry_sdk

class Conversation:
Expand Down
19 changes: 11 additions & 8 deletions agent/chain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.prompts import (
from langchain_community.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain_core.prompts import (
SystemMessagePromptTemplate,
)
from langchain.prompts import load_prompt, ChatPromptTemplate
from langchain.schema import AIMessage, HumanMessage, BaseMessage
from langchain_core.prompts import load_prompt, ChatPromptTemplate
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
from langchain_mistralai.chat_models import ChatMistralAI

from openai import BadRequestError

Expand All @@ -25,9 +26,11 @@
class BloomChain:
"Wrapper class for encapsulating the multiple different chains used in reasoning for the tutor's thoughts"
# llm: ChatOpenAI = ChatOpenAI(model_name = "gpt-4", temperature=1.2)
llm: AzureChatOpenAI | ChatOpenAI
if (os.environ.get("OPENAI_API_TYPE") == "azure"):
llm: AzureChatOpenAI | ChatOpenAI | ChatMistralAI
if (os.environ["LLM_API"] == "azure"):
llm = AzureChatOpenAI(deployment_name = os.environ['OPENAI_API_DEPLOYMENT_NAME'], temperature=1.2, model_kwargs={"top_p": 0.5})
if (os.environ['LLM_API'] == "mistral"):
llm = ChatMistralAI(model="mistral-medium", mistral_api_key=os.environ['MISTRAL_API_KEY'], temperature=1, top_p=0.5)
else:
llm = ChatOpenAI(model_name = "gpt-4", temperature=1.2, model_kwargs={"top_p": 0.5})

Expand Down Expand Up @@ -58,8 +61,8 @@ def think(cls, cache: Conversation, input: str):
chain = thought_prompt | cls.llm

def save_new_messages(ai_response):
cache.add_message("response", HumanMessage(content=input))
cache.add_message("response", AIMessage(content=ai_response))
cache.add_message("thought", HumanMessage(content=input))
cache.add_message("thought", AIMessage(content=ai_response))

return Streamable(chain.astream({}, {"tags": ["thought"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}), save_new_messages)

Expand Down
4 changes: 2 additions & 2 deletions agent/mediator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
import uuid
import os
import sentry_sdk
Expand Down Expand Up @@ -35,7 +35,7 @@ def add_message(self, session_id: str, user_id: str, message_type: str, message:
"session_id": session_id,
"user_id": user_id,
"message_type": message_type,
"message": _message_to_dict(message)
"message": message_to_dict(message)
}
self.supabase.table(self.memory_table).insert(payload).execute()

Expand Down
16 changes: 8 additions & 8 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fastapi.middleware.cors import CORSMiddleware
# from fastapi.staticfiles import StaticFiles

from langchain.schema import _message_to_dict
from langchain_core.messages import message_to_dict
import sentry_sdk

import os
Expand All @@ -20,12 +20,12 @@
load_dotenv()


rate = 0.2 if os.getenv("SENTRY_ENVIRONMENT") == "production" else 1.0
sentry_sdk.init(
dsn=os.environ['SENTRY_DSN_API'],
traces_sample_rate=rate,
profiles_sample_rate=rate
)
# rate = 0.2 if os.getenv("SENTRY_ENVIRONMENT") == "production" else 1.0
# sentry_sdk.init(
# dsn=os.environ['SENTRY_DSN_API'],
# traces_sample_rate=rate,
# profiles_sample_rate=rate
# )

app = FastAPI()

Expand Down Expand Up @@ -102,7 +102,7 @@ async def update_conversations(change: ConversationDefinition):
async def get_messages(user_id: str, conversation_id: str):
async with LOCK:
messages = MEDIATOR.messages(user_id=user_id, session_id=conversation_id, message_type="response", limit=(False, None))
converted_messages = [_message_to_dict(_message) for _message in messages]
converted_messages = [message_to_dict(_message) for _message in messages]
return {
"messages": converted_messages
}
Expand Down
Empty file removed bot/__init__.py
Empty file.
31 changes: 0 additions & 31 deletions bot/app.py

This file was deleted.

180 changes: 0 additions & 180 deletions bot/core.py

This file was deleted.

Loading

0 comments on commit cc7975e

Please sign in to comment.