Commit

updating-openai
ieaves committed Dec 8, 2023
1 parent 6d792df commit 213892c
Showing 5 changed files with 216 additions and 556 deletions.
198 changes: 109 additions & 89 deletions grai-server/app/grAI/chat_implementations.py
@@ -9,6 +9,7 @@
from typing import Annotated, Any, Callable, Literal, ParamSpec, Type, TypeVar, Union
import itertools
import openai
+ from django.conf import settings

from pgvector.django import MaxInnerProduct
import tiktoken
@@ -24,6 +25,7 @@
from lineage.models import Edge, Node, NodeEmbeddings
from workspaces.models import Workspace
from django.db.models.expressions import RawSQL
+ from openai.types.chat.chat_completion_message import ChatCompletionMessage

logging.basicConfig(level=logging.DEBUG)

@@ -67,22 +69,23 @@ class AIMessage(BaseMessage):


class FunctionMessage(BaseMessage):
role: Literal["function"] = "function"
role: Literal["tool"] = "tool"
name: str
+ tool_call_id: str

def representation(self) -> dict:
return {"role": self.role, "content": self.content, "name": self.name}
return {"tool_call_id": self.tool_call_id, "role": self.role, "content": self.content, "name": self.name}


class ChatMessage(BaseModel):
- message: Union[UserMessage, SystemMessage, AIMessage, FunctionMessage]
+ message: Union[UserMessage, SystemMessage, AIMessage, FunctionMessage, ChatCompletionMessage]


class ChatMessages(BaseModel):
- messages: list[BaseMessage]
+ messages: list[BaseMessage | ChatCompletionMessage]

def to_gpt(self) -> list[dict]:
- return [message.representation() for message in self.messages]
+ return [message.representation() if isinstance(message, BaseMessage) else message for message in self.messages]

def __getitem__(self, index):
return self.messages[index]
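For reference, a minimal sketch of the wire format the updated FunctionMessage now emits: in openai>=1.0 a function result goes back to the model as a role="tool" message that must echo the id of the assistant's originating tool call. The field values below are hypothetical.

# Hypothetical payload produced by FunctionMessage.representation() under the tools API.
tool_result = {
    "tool_call_id": "call_abc123",  # must match an id in the assistant's tool_calls
    "role": "tool",
    "content": '[("public.users.id", "prod")]',  # stringified function output
    "name": "node_lookup",  # hypothetical tool name
}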
@@ -96,9 +99,6 @@ def append(self, item):
def extend(self, items):
self.messages.extend(items)

- def token_length(self) -> int:
- return sum(m.token_length for m in self.messages)


def get_token_limit(model_type: str) -> int:
OPENAI_TOKEN_LIMITS = {"gpt-4": 8192, "gpt-3.5-turbo": 4096, "gpt-3.5-turbo-16k": 16385, "gpt-4-32k": 32768}
@@ -143,7 +143,10 @@ async def response(self, **kwargs) -> str:
return result

def gpt_definition(self) -> dict:
return {"name": self.id, "description": self.description, "parameters": self.schema_model.schema()}
return {
"type": "function",
"function": {"name": self.id, "description": self.description, "parameters": self.schema_model.schema()},
}
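A sketch of the shape change: the legacy functions API took flat {name, description, parameters} entries, while the tools API wraps each one in a {"type": "function", ...} envelope, which is what gpt_definition() now returns. The function name and parameter schema below are illustrative.

# Illustrative tools-API definition (hypothetical name and schema).
tool_definition = {
    "type": "function",
    "function": {
        "name": "node_lookup",
        "description": "Look up nodes by name and namespace.",
        "parameters": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        },
    },
}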


class NodeIdentifier(BaseModel):
Expand Down Expand Up @@ -324,7 +327,7 @@ def response_message(result_set: list[str]) -> str | None:
if total_results == 0:
message = "No results found matching these query conditions."
else:
message = "Results are returned in the following format: (source.name, source.namespace) -> (destination.name, destination.namespace)"
message = "Results are returned in the following format: (source name, source namespace) -> (destination name, destination namespace)"

return message

@@ -382,23 +385,25 @@ class EmbeddingSearchAPI(API):
description: str = "Search for nodes which match any query."
schema_model = EmbeddingSearchSchema

- def __init__(self, workspace: str | uuid.UUID):
+ def __init__(self, workspace: str | uuid.UUID, client: openai.AsyncOpenAI):
self.workspace = workspace
+ self.client = client
self.model_type = "text-embedding-ada-002"
- self.token_limit = 4000
+ self.token_limit = 8000
self.encoder = tiktoken.encoding_for_model(self.model_type)

@staticmethod
- def response_message(result_set: list[str]) -> str | None:
+ def response_message(result_set: list) -> str:
total_results = len(result_set)
if total_results == 0:
message = "No results found matching these query conditions."
else:
message = "Results are returned in the following format: (source.name, source.namespace) -> (destination.name, destination.namespace)"
message = "Results are returned in the following format: (source name, source namespace) -> (destination name, destination namespace)"

return message

- def nearest_neighbor_search(self, vector_query: list[int], limit=10) -> list[Node]:
+ @database_sync_to_async
+ def nearest_neighbor_search(self, vector_query: list[int], limit=10) -> tuple[list[Node], str]:
node_result = (
NodeEmbeddings.objects.filter(node__workspace__id=self.workspace)
.order_by(MaxInnerProduct("embedding", vector_query))
@@ -407,17 +412,16 @@ def nearest_neighbor_search(self, vector_query: list[int], limit=10) -> list[Nod

return [n.node for n in node_result]

- @database_sync_to_async
- def call(self, **kwargs):
+ async def call(self, **kwargs):
try:
inp = self.schema_model(**kwargs)
except:
return [], self.response_message([])

search_term = self.encoder.decode(self.encoder.encode(inp.search_term)[: self.token_limit])
- embedding_resp = openai.Embedding.create(input=search_term, model=self.model_type)
+ embedding_resp = await self.client.embeddings.create(input=search_term, model=self.model_type)
embedding = list(embedding_resp.data[0].embedding)
- neighbors = self.nearest_neighbor_search(embedding, inp.limit)
+ neighbors = await self.nearest_neighbor_search(embedding, inp.limit)
response = [(n.name, n.namespace) for n in neighbors]
return response, self.response_message(response)
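The call() migration above swaps the removed module-level openai.Embedding.create for the client method, awaited on AsyncOpenAI. A self-contained sketch of that pattern, assuming an API key is available in the environment:

import openai

async def embed(text: str) -> list[float]:
    # openai>=1.0: embeddings live on the client; async callers await them.
    client = openai.AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    resp = await client.embeddings.create(input=text, model="text-embedding-ada-002")
    return list(resp.data[0].embedding)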

@@ -457,6 +461,7 @@ def __init__(
self,
chat_id: str,
prompt: str,
+ client: openai.AsyncOpenAI | None = None,
model_type: str = settings.OPENAI_PREFERRED_MODEL,
user: str = str(uuid.uuid4()),
functions: list = None,
@@ -467,6 +472,7 @@ def __init__(

self.model_type = model_type
self.token_limit = get_token_limit(self.model_type)
+ self.model_limit = self.token_limit * 0.9
self.encoder = tiktoken.encoding_for_model(self.model_type)
# self.encoder = FakeEncoder()
self.chat_id = chat_id
@@ -478,6 +484,10 @@ def __init__(

self.prompt_message = self.build_message(SystemMessage, content=self.system_context)

+ if client is None:
+ client = openai.AsyncOpenAI(api_key=settings.OPENAI_API_KEY, organization=settings.OPENAI_ORG_ID)
+ self.client = client

def build_message(self, message_type: Type[T], content: str, **kwargs) -> T:
return message_type(content=content, token_length=len(self.encoder.encode(content)), **kwargs)
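Accepting an optional client makes the conversation easy to test or point at a proxy, while callers that pass nothing get the settings-based fallback above. A hedged sketch of injection; the base_url endpoint is hypothetical:

# Sketch: injecting a preconfigured client instead of relying on settings.
client = openai.AsyncOpenAI(api_key="sk-test", base_url="http://localhost:8080/v1")  # hypothetical endpoint
conversation = BaseConversation(chat_id="chat-1", prompt="You are helpful.", client=client)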

@@ -507,15 +517,13 @@ def hydrate_chat(self):
async def summarize(self, messages: list[BaseMessage]) -> AIMessage:
summary_prompt = """
Please summarize this conversation encoding the most important information a future agent would need to continue
- working on the problem with me. Please insure you do not call any functions providing an exclusively
+ working on the problem with me. Please ensure your response is a
text based summary of the conversation to this point with all relevant context for the next agent.
"""
- message = self.build_message(UserMessage, summary_prompt)
- summary_messages = ChatMessages(messages=[*messages, message])
+ summary_message = SystemMessage(content=summary_prompt)
+ summary_messages = ChatMessages(messages=[self.prompt_message, *messages, summary_message])
logging.info(f"Summarizing conversation for chat: {self.chat_id}")
- response = await openai.ChatCompletion.acreate(
- model=self.model_type, user=self.user, messages=summary_messages.to_gpt()
- )
+ response = await self.client.chat.completions.create(model=self.model_type, messages=summary_messages.to_gpt())

# this is hacky for now
summary_message = self.build_message(AIMessage, response.choices[0].message.content)
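Note the shape change that makes this work: the 1.x client returns typed objects rather than dicts, so choices, message, and content are attribute access. A minimal usage sketch; the model name and message are illustrative:

# Sketch of a chat completion call and its typed response under openai>=1.0.
async def summarize_once(client: openai.AsyncOpenAI) -> str:
    resp = await client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": "Summarize our discussion so far."}],
    )
    return resp.choices[0].message.content  # typed attribute access, not dict keys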
@@ -525,36 +533,45 @@ def functions(self):
return [func.gpt_definition() for func in self.api_functions.values()]

@property
- def model(self) -> Callable[P, R]:
- model = partial(openai.ChatCompletion.acreate, model=self.model_type, user=self.user)
+ def model(self) -> R:
+ base_kwargs = {"model": self.model_type}
if len(functions := self.functions()) > 0:
- model = partial(model, functions=functions)
+ base_kwargs |= {"tools": functions, "tool_choice": "auto"}

+ def inner(**kwargs):
+ messages = kwargs.pop("messages", [])
+ messages = [self.prompt_message.representation(), *messages]
+ return self.client.chat.completions.create(
+ messages=messages,
+ **base_kwargs,
+ **kwargs,
+ )

- return model
+ return inner
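The partial-based callable is gone; inner() closes over the system prompt and tool definitions and forwards everything to the client. A sketch of the request it ultimately issues, with an abbreviated, hypothetical tool list:

# Approximate request inner() issues when tools are registered.
async def call_with_tools(client, messages, tools):
    return await client.chat.completions.create(
        model="gpt-4",
        messages=messages,
        tools=tools,             # e.g. [tool_definition] sketched earlier
        tool_choice="auto",      # let the model decide whether to call a tool
    )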

async def evaluate_summary(self, messages: ChatMessages) -> ChatMessages:
- model_limit = int(self.token_limit * 0.85) # this needs to be calculated from the openai response
- requires_summary = messages.token_length() > model_limit
- while requires_summary:
- prev_accumulated_tokens = 0
- accumulated_tokens = 0
+ while True:
+ prev_accumulated_tokens = self.prompt_message.token_length
+ accumulated_tokens = self.prompt_message.token_length
+ i = 0
for i, message in enumerate(messages.messages):
- accumulated_tokens += message.token_length
- if accumulated_tokens > model_limit:
+ if hasattr(message, "token_length"):
+ accumulated_tokens += message.token_length
+ elif hasattr(message, "content") and message.content is not None:
+ accumulated_tokens += len(self.encoder.encode(message.content))
+ if accumulated_tokens > self.model_limit:
break
else:
prev_accumulated_tokens = accumulated_tokens

- available_tokens = model_limit - prev_accumulated_tokens
+ available_tokens = self.model_limit - prev_accumulated_tokens
message = messages.messages[i]
- if i == len(messages) and accumulated_tokens < model_limit:
- requires_summary = False
+ if i == len(messages) and accumulated_tokens < self.model_limit:
break
elif available_tokens >= message.token_length:
summary = await self.summarize(messages.messages[: (i + 1)])
- messages = [self.prompt_message, summary, *messages.messages[(i + 1) :]]
+ messages = [summary, *messages.messages[(i + 1) :]]
else:
encoding = self.encoder.encode(message.content)

@@ -567,88 +584,91 @@ async def evaluate_summary(self, messages: ChatMessages) -> ChatMessages:
next_message_obj.token_length = len(encoding[available_tokens:])

summary = await self.summarize([*messages.messages[:i], message_obj])
- messages = [self.prompt_message, summary, next_message_obj, *messages.messages[(i + 1) :]]
+ messages = [summary, next_message_obj, *messages.messages[(i + 1) :]]

messages = ChatMessages(messages=messages)
return messages
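Because the message list can now mix internal BaseMessage objects (which cache token_length) with raw ChatCompletionMessage objects from the SDK (which do not), the budget walk above branches on the attribute. A standalone sketch of that mixed counting, assuming gpt-4 encoding:

import tiktoken

encoder = tiktoken.encoding_for_model("gpt-4")

def message_tokens(message) -> int:
    # Internal messages cache their length; SDK messages are encoded on demand.
    if hasattr(message, "token_length"):
        return message.token_length
    if getattr(message, "content", None) is not None:
        return len(encoder.encode(message.content))
    return 0  # e.g. an assistant message carrying only tool_calls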

async def request(self, user_input: str) -> str:
logging.info(f"Responding to request for: {self.chat_id}")

- messages = ChatMessages(
- messages=[self.prompt_message, *self.cached_messages, self.build_message(UserMessage, content=user_input)]
- )
+ user_message = self.build_message(UserMessage, content=user_input)
+ messages = ChatMessages(messages=[*self.cached_messages, user_message])

+ final_response: str | None = None
+ usage = 0

+ while not final_response:
+ new_messages = []
+ if usage > self.token_limit:
+ messages = await self.evaluate_summary(messages)
- result = None
- stop = False
- n = 0
- while not stop:
- messages = await self.evaluate_summary(messages)
response = await self.model(messages=messages.to_gpt())

- if result:
- for choice in response.choices:
- if result != choice:
- result = choice
- break
- else:
- result = response.choices[0]
- if stop := result.finish_reason == "stop":
- message = self.build_message(AIMessage, content=result.message.content)
- messages.append(message)
- elif result.finish_reason == "function_call":
- func_id = result.message.function_call.name
- func_kwargs = json.loads(result.message.function_call.arguments)
- api = self.api_functions.get(func_id, InvalidAPI(self.api_functions.values()))
- response = await api.response(**func_kwargs)
- if isinstance(api, InvalidAPI):
- message = self.build_message(SystemMessage, response)
- messages.append(message)
- else:
- message = self.build_message(FunctionMessage, content=response, name=func_id)
- messages.append(message)
- elif result.finish_reason == "length":
- summary = await self.summarize(messages[:-1])
- messages = ChatMessages(messages=[self.prompt_message, summary, messages[-1]])
- else:
- # valid values include length, content_filter, null
- raise NotImplementedError(f"No stop reason for {result.finish_reason}")
+ usage = response.usage.total_tokens
+ response_choice = response.choices[0]
+ response_message = response_choice.message
+ messages.append(response_message)
+ if (finish_reason := response_choice.finish_reason) == "stop":
+ final_response = response_message.content
+ elif finish_reason == "length":
+ summary = await self.summarize(messages[:-1])
+ messages = ChatMessages(messages=[summary, messages[-1]])
+ elif response_message.tool_calls:
+ new_messages = []
+ for tool_call in response_message.tool_calls:
+ func_id = tool_call.function.name
+ func_kwargs = json.loads(tool_call.function.arguments)
+ api = self.api_functions.get(func_id, InvalidAPI(self.api_functions.values()))
+ response = await api.response(**func_kwargs)
+ if isinstance(api, InvalidAPI):
+ message = self.build_message(SystemMessage, response)
+ new_messages.append(message)
+ else:
+ message = self.build_message(
+ FunctionMessage, content=response, name=func_id, tool_call_id=tool_call.id
+ )
+ new_messages.append(message)
+ usage += sum([m.token_length for m in new_messages])
+ messages.extend(new_messages)
self.cached_messages = messages.messages
- return result.message.content
+ return final_response
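Stripped of summarization and caching, request() now follows the standard tools loop: call the model, run any tool_calls, append role="tool" results, and repeat until the model stops. A condensed, self-contained sketch of that pattern; the dispatch registry is hypothetical:

import json

# Condensed sketch of the tool-call loop; `dispatch` maps tool names to coroutines.
async def run(client, messages, tools, dispatch, model="gpt-4"):
    while True:
        resp = await client.chat.completions.create(
            model=model, messages=messages, tools=tools, tool_choice="auto"
        )
        msg = resp.choices[0].message
        messages.append(msg)  # the 1.x SDK serializes its own message objects when passed back
        if not msg.tool_calls:
            return msg.content  # finish_reason == "stop"
        for call in msg.tool_calls:
            result = await dispatch[call.function.name](**json.loads(call.function.arguments))
            messages.append({"role": "tool", "tool_call_id": call.id,
                             "name": call.function.name, "content": str(result)})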


async def get_chat_conversation(
chat_id: str | uuid.UUID, workspace: str | uuid.UUID, model_type: str = settings.OPENAI_PREFERRED_MODEL
):
chat_prompt = """
You are a helpful assistant with domain expertise about an organizations data and data infrastructure.
- Before you can help the user, you need to understand the context of their request and what they are trying to accomplish.
+ You should attempt to understand the context of the request and what the user is trying to accomplish.
+ If a user asks about specific data like nodes and you're unable to find an answer you should attempt to find a similar node and explain why you think it's similar.
+ Your responses must use Markdown syntax.
* You know how to query for additional context and metadata about any data in the organization.
+ Data Structure Notes:
* Unique pieces of data like a column in a database is identified by a (name, namespace) tuple or a unique uuid.
* You can help users discover new data or identify and correct issues such as broken data pipelines, and BI dashboards.
- * Your responses must use Markdown syntax
* You are proactive in looking up additional data to answer any user question.
* Attempt to explain your reasoning in each answer.
* Nodes contain a metadata field with extra context about the node.
* Nodes and Edges are typed. You can identify the type under `metadata.grai.node_type` or `metadata.grai.edge_type`
* If a Node has a type like `Column` with a `TableToColumn` Edge connecting to a `Table` node, the Column node represents a column in the table.
* Node names for databases and datawarehouses are constructed following `{schema}.{table}.{column}` format e.g. a column named `id` in a table named `users` in a schema named `public` would be identified as `public.users.id`
- * If a user asks about a specific node and you're unable to find an answer you should attempt to find a similar node and explain why you think it's similar.
"""
+ client = openai.AsyncOpenAI(api_key=settings.OPENAI_API_KEY, organization=settings.OPENAI_ORG_ID)

functions = [
NodeLookupAPI(workspace=workspace),
# Todo: edge lookup is broken
# EdgeLookupAPI(workspace=workspace),
# FuzzyMatchNodesAPI(workspace=workspace),
- EmbeddingSearchAPI(workspace=workspace),
+ EmbeddingSearchAPI(workspace=workspace, client=client),
NHopQueryAPI(workspace=workspace),
]

conversation = BaseConversation(
- prompt=chat_prompt, model_type=model_type, functions=functions, chat_id=str(chat_id)
+ prompt=chat_prompt, model_type=model_type, functions=functions, chat_id=str(chat_id), client=client
)
await conversation.hydrate_chat()
return conversation
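A hedged usage sketch of the updated entry point; the workspace id and question are illustrative:

async def ask(workspace_id):
    # Hypothetical caller: build a conversation and issue one request.
    conversation = await get_chat_conversation(chat_id=uuid.uuid4(), workspace=workspace_id)
    return await conversation.request("Which tables feed the revenue dashboard?")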
2 changes: 1 addition & 1 deletion grai-server/app/grAI/encoders.py
@@ -34,4 +34,4 @@ def get_max_length_content(self, content: str) -> str:

def get_embedding(self, content: str) -> R:
content = self.get_max_length_content(content)
- return openai.Embedding.create(input=content, model=self.model)
+ return openai.embeddings.create(input=content, model=self.model)
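For reference, a minimal sketch of the synchronous module-level helper this file now relies on; in openai>=1.0 the proxy is openai.embeddings (plural), and it reads OPENAI_API_KEY from the environment:

import openai

def get_embedding(content: str) -> list[float]:
    # Module-level convenience client (openai>=1.0); key comes from the environment.
    resp = openai.embeddings.create(input=content, model="text-embedding-ada-002")
    return list(resp.data[0].embedding)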
