diff --git a/grai-server/app/grAI/chat_implementations.py b/grai-server/app/grAI/chat_implementations.py index a23fa0512..7cb74bc9d 100644 --- a/grai-server/app/grAI/chat_implementations.py +++ b/grai-server/app/grAI/chat_implementations.py @@ -1,31 +1,37 @@ import copy import json import logging -import operator import uuid -from abc import ABC, abstractmethod -from functools import cached_property, partial, reduce -from itertools import accumulate -from typing import Annotated, Any, Callable, Literal, ParamSpec, Type, TypeVar, Union -import itertools -import openai -from django.conf import settings -from pgvector.django import MaxInnerProduct +from typing import Annotated, Any, Callable, Literal, ParamSpec, Type, TypeVar, Union, Coroutine + +from openai import AsyncStream +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.completion_usage import CompletionUsage +import openai +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from workspaces.models import Workspace import tiktoken from django.conf import settings from django.core.cache import cache -from django.db.models import Q -from grai_schemas.serializers import GraiYamlSerializer -from pydantic import BaseModel, Field +from itertools import islice +import asyncio from channels.db import database_sync_to_async -from connections.adapters.schemas import model_to_schema + from grAI.models import Message -from lineage.models import Edge, Node, NodeEmbeddings -from workspaces.models import Workspace -from django.db.models.expressions import RawSQL -from openai.types.chat.chat_completion_message import ChatCompletionMessage +from grAI.chat_types import ( + BaseMessage, + UserMessage, + SystemMessage, + AIMessage, + FunctionMessage, + ChatMessages, + UsageMessage, + ChatMessage, + to_gpt, +) +from grAI.tools import NodeLookupAPI, EdgeLookupAPI, FuzzyMatchNodesAPI, EmbeddingSearchAPI, NHopQueryAPI, InvalidAPI logging.basicConfig(level=logging.DEBUG) @@ -36,68 +42,11 @@ R = TypeVar("R") P = ParamSpec("P") -RoleType = Union[Literal["user"], Literal["system"], Literal["assistant"]] - - -class BaseMessage(BaseModel): - role: str - content: str - token_length: int - - def representation(self) -> dict: - return {"role": self.role, "content": self.content} - - def chunk_content(self, n_chunks: int = 2) -> list[str]: - if self.token_length is None: - raise ValueError("Cannot chunk content without a token length") - - chunk_size = self.token_length // n_chunks - chunks = [self.content[i : i + chunk_size] for i in range(0, len(self.content), chunk_size)] - return chunks - - -class UserMessage(BaseMessage): - role: Literal["user"] = "user" - - -class SystemMessage(BaseMessage): - role: Literal["system"] = "system" - - -class AIMessage(BaseMessage): - role: Literal["assistant"] = "assistant" - - -class FunctionMessage(BaseMessage): - role: Literal["tool"] = "tool" - name: str - tool_call_id: str - - def representation(self) -> dict: - return {"tool_call_id": self.tool_call_id, "role": self.role, "content": self.content, "name": self.name} - - -class ChatMessage(BaseModel): - message: Union[UserMessage, SystemMessage, AIMessage, FunctionMessage, ChatCompletionMessage] - -class ChatMessages(BaseModel): - messages: list[BaseMessage | ChatCompletionMessage] - - def to_gpt(self) -> list[dict]: - return [message.representation() if isinstance(message, BaseMessage) else message for message in self.messages] - - def __getitem__(self, index): - return self.messages[index] - - def __len__(self) -> 
int: - return len(self.messages) - - def append(self, item): - self.messages.append(item) - - def extend(self, items): - self.messages.extend(items) +def chunker(it, size): + iterator = iter(it) + while chunk := list(islice(iterator, size)): + yield chunk def get_token_limit(model_type: str) -> int: @@ -115,345 +64,15 @@ def get_token_limit(model_type: str) -> int: return 2049 -class API(ABC): - schema_model: BaseModel - description: str - id: str - - @abstractmethod - def call(self, **kwargs) -> (Any, str): - pass - - def serialize(self, result) -> str: - if isinstance(result, str): - return result - - return GraiYamlSerializer.dump(result) - - async def response(self, **kwargs) -> str: - logging.info(f"Calling {self.id} with {kwargs}") - obj, message = await self.call(**kwargs) - - logging.info(f"Building Response message for {self.id} with {kwargs}") - if message is None: - result = self.serialize(obj) - else: - result = f"{self.serialize(obj)}\n{message}" - - return result - - def gpt_definition(self) -> dict: - return { - "type": "function", - "function": {"name": self.id, "description": self.description, "parameters": self.schema_model.schema()}, - } - - -class NodeIdentifier(BaseModel): - name: str = Field(description="The name of the node to query for") - namespace: str = Field(description="The namespace of the node to query for") - - -class NodeLookup(BaseModel): - nodes: list[NodeIdentifier] = Field(description="A list of nodes to lookup") - - -class NodeLookupAPI(API): - id = "node_lookup" - description = "Lookup metadata about one or more nodes if you know precisely which node(s) to lookup" - schema_model = NodeLookup - - def __init__(self, workspace: str | uuid.UUID): - self.workspace = workspace - self.query_limit = MAX_RETURN_LIMIT - - @staticmethod - def response_message(result_set: list[Node]) -> str | None: - total_results = len(result_set) - if total_results == 0: - message = "No results found matching these query conditions." - else: - message = None - - return message - - @database_sync_to_async - def call(self, **kwargs) -> (list[Node], str | None): - try: - validation = self.schema_model(**kwargs) - except: - return [], "Invalid input. Please check your input and try again." - q_objects = (Q(**node.dict(exclude_none=True)) for node in validation.nodes) - query = reduce(operator.or_, q_objects) - result_set = Node.objects.filter(workspace=self.workspace).filter(query).order_by("-created_at").all() - response_items = model_to_schema(result_set[: self.query_limit], "NodeV1") - return response_items, self.response_message(result_set) - - -class FuzzyMatchQuery(BaseModel): - string: str = Field(description="The fuzzy string used to search amongst node names") - - -class FuzzyMatchNodesAPI(API): - id = "node_fuzzy_lookup" - description = "Performs a fuzzy search for nodes matching a name regardless of namespace" - schema_model = FuzzyMatchQuery - - def __init__(self, workspace: str | uuid.UUID): - self.workspace = workspace - self.query_limit = MAX_RETURN_LIMIT - - @staticmethod - def response_message(result_set: list[Node]) -> str | None: - total_results = len(result_set) - if total_results == 0: - message = "No results found matching these query conditions." 
- else: - message = None - - return message - - @database_sync_to_async - def call(self, string: str) -> (list, str | None): - result_set = ( - Node.objects.filter(workspace=self.workspace).filter(name__contains=string).order_by("-created_at").all() - ) - response_items = [{"name": node.name, "namespace": node.namespace} for node in result_set] - - return response_items, self.response_message(result_set) - - -class EdgeLookupSchema(BaseModel): - source: uuid.UUID | None = Field(description="The primary key of the source node on an edge", default=None) - destination: uuid.UUID | None = Field( - description="The primary key of the destination node on an edge", default=None - ) - - -class MultiEdgeLookup(BaseModel): - edges: list[EdgeLookupSchema] = Field( - description="List of edges to lookup. Edges can be uniquely identified by a (name, namespace) tuple, or by a (source, destination) tuple of the nodes the edge connects" - ) - - -class EdgeLookupAPI(API): - id = "edge_lookup" - description = """ - This function Supports looking up edges from a data lineage graph. For example, a query with name=Test but no - namespace value will return all edges explicitly named "Test" regardless of namespace. - Edges are uniquely identified both by their (name, namespace), and by the (source, destination) nodes they connect. - """ - schema_model = MultiEdgeLookup - - def __init__(self, workspace: str | uuid.UUID): - self.workspace = workspace - self.query_limit = MAX_RETURN_LIMIT - - @staticmethod - def response_message(result_set: list[Edge]) -> str | None: - total_results = len(result_set) - if total_results == 0: - message = "No results found matching these query conditions." - else: - message = None - - return message - - @database_sync_to_async - def call(self, **kwargs) -> (list[Edge], str | None): - validation = self.schema_model(**kwargs) - q_objects = (Q(**node.dict(exclude_none=True)) for node in validation.edges) - query = reduce(operator.or_, q_objects) - result_set = Edge.objects.filter(workspace=self.workspace).filter(query).all()[: self.query_limit] - return model_to_schema(result_set[: self.query_limit], "EdgeV1"), self.response_message(result_set) - - -class EdgeFuzzyLookupSchema(BaseModel): - name__contains: str | None = Field( - description="The name of the edge to lookup perform a fuzzy search on", default=None - ) - namespace__contains: str | None = Field( - description="The namespace of the edge to lookup perform a fuzzy search on", default=None - ) - is_active: bool | None = Field(description="Whether or not the edge is active", default=True) - - -class MultiFuzzyEdgeLookup(BaseModel): - edges: list[EdgeLookupSchema] = Field( - description="List of edges to lookup. Edges can be uniquely identified by a (name, namespace) tuple, or by a (source, destination) tuple of the nodes the edge connects" - ) - - -class EdgeFuzzyLookupAPI(EdgeLookupAPI): - id = "edge_fuzzy_lookup" - description = """ - This function Supports looking up edges from a data lineage graph. For example, a query with name__contains=test - but no namespace value will return all edges whose names contain the substring "test" regardless of namespace. - Edges are uniquely identified both by their (name, namespace), and by the (source, destination) nodes they connect. 
- """ - schema_model = MultiFuzzyEdgeLookup - - -class NodeEdgeSerializer: - def __init__(self, nodes, edges): - self.nodes = nodes - self.edges = edges - - def representation(self, path=None): - items = [item.spec for item in (*self.nodes, *self.edges)] - return GraiYamlSerializer.dump(items, path) - - def __str__(self): - return self.representation() - - -class NHopQuerySchema(BaseModel): - name: str = Field(description="The name of the node to query for") - namespace: str = Field(description="The namespace of the node to query for") - n: int = Field(description="The number of hops to query for", default=1) - - -class NHopQueryAPI(API): - id: str = "n_hop_query" - description: str = "query for nodes and edges within a specified number of hops from a given node" - schema_model = NHopQuerySchema - - def __init__(self, workspace: str | uuid.UUID): - self.workspace = workspace - - @staticmethod - def response_message(result_set: list[str]) -> str | None: - total_results = len(result_set) - if total_results == 0: - message = "No results found matching these query conditions." - else: - message = "Results are returned in the following format: (source name, source namespace) -> (destination name, destination namespace)" - - return message - - @staticmethod - def filter(queryset: list[Edge], source_nodes: list[Node], dest_nodes: list[Node]) -> tuple[list[Node], list[Node]]: - def get_id(node: Node) -> tuple[str, str]: - return node.name, node.namespace - - source_ids: set[T] = {get_id(node) for node in source_nodes} - dest_ids: set[T] = {get_id(node) for node in dest_nodes} - query_hashes: set[tuple[T, T]] = {(get_id(node.source), get_id(node.destination)) for node in queryset} - - source_resp = [n.destination for n, hashes in zip(queryset, query_hashes) if hashes[0] in source_ids] - dest_resp = [n.source for n, hashes in zip(queryset, query_hashes) if hashes[1] in dest_ids] - return source_resp, dest_resp - - @database_sync_to_async - def call(self, **kwargs) -> (list[Edge | Node], str | None): - def edge_label(edge: Edge): - return f"({edge.source.name}, {edge.source.namespace}) -> ({edge.destination.name}, {edge.destination.namespace})" - - try: - inp = self.schema_model(**kwargs) - except Exception as e: - return [], f"Invalid input: {e}" - - source_node = Node.objects.filter(workspace__id=self.workspace, name=inp.name, namespace=inp.namespace).first() - if source_node is None: - return [], self.response_message([]) - - source_nodes = [source_node] - dest_nodes = [source_node] - - return_edges = [] - for i in range(inp.n): - query = Q(source__in=source_nodes) | Q(destination__in=dest_nodes) - edges = Edge.objects.filter(query).select_related("source", "destination").all() - - source_nodes, dest_nodes = self.filter(edges, source_nodes, dest_nodes) - return_edges.extend((edge_label(item) for item in edges)) - - if len(source_nodes) == 0 and len(dest_nodes) == 0: - break - - return return_edges, self.response_message(return_edges) - - -class EmbeddingSearchSchema(BaseModel): - search_term: str = Field(description="Search request string") - limit: int = Field(description="number of results to return", default=10) - - -class EmbeddingSearchAPI(API): - id: str = "embedding_search_api" - description: str = "Search for nodes which match any query." 
- schema_model = EmbeddingSearchSchema - - def __init__(self, workspace: str | uuid.UUID, client: openai.AsyncOpenAI): - self.workspace = workspace - self.client = client - self.model_type = "text-embedding-ada-002" - self.token_limit = 8000 - self.encoder = tiktoken.encoding_for_model(self.model_type) - - @staticmethod - def response_message(result_set: list) -> str: - total_results = len(result_set) - if total_results == 0: - message = "No results found matching these query conditions." - else: - message = "Results are returned in the following format: (source name, source namespace) -> (destination name, destination namespace)" - - return message - - @database_sync_to_async - def nearest_neighbor_search(self, vector_query: list[int], limit=10) -> tuple[list[Node], str]: - node_result = ( - NodeEmbeddings.objects.filter(node__workspace__id=self.workspace) - .order_by(MaxInnerProduct("embedding", vector_query)) - .select_related("node")[:limit] - ) - - return [n.node for n in node_result] - - async def call(self, **kwargs): - try: - inp = self.schema_model(**kwargs) - except: - return [], self.response_message([]) - - search_term = self.encoder.decode(self.encoder.encode(inp.search_term)[: self.token_limit]) - embedding_resp = await self.client.embeddings.create(input=search_term, model=self.model_type) - embedding = list(embedding_resp.data[0].embedding) - neighbors = await self.nearest_neighbor_search(embedding, inp.limit) - response = [(n.name, n.namespace) for n in neighbors] - return response, self.response_message(response) - - -class InvalidApiSchema(BaseModel): - pass - - -class InvalidAPI(API): - id = "invalid_api_endpoint" - description = "placeholder" - schema_model = InvalidApiSchema - - def __init__(self, apis): - self.function_string = ", ".join([f"`{api.id}`" for api in apis]) - - @database_sync_to_async - def call(self, *args, **kwargs): - return f"Invalid API Endpoint. That function does not exist. The supported apis are {self.function_string}" - - def serialize(self, result): - return result - - -class FakeEncoder: - def encode(self, text): - return [1, 2, 3, 4] - - -def pre_compute_graph(workspace: str | uuid.UUID): - query_filter = Q(workspace=workspace) & Q(is_active=True) - edges = Edge.objects.filter(query_filter).all() +class SummaryPrompt: + def __init__(self, encoder: tiktoken.Encoding): + prompt_str = """ + Please summarize this conversation encoding the most important information a future agent would need to continue + working on the problem with me. Please insure your response is a + text based summary of the conversation to this point with all relevant context for the next agent. 
+ """ + self.prompt = SystemMessage(content=prompt_str) + self.token_usage = len(encoder.encode(prompt_str)) class BaseConversation: @@ -463,7 +82,6 @@ def __init__( prompt: str, client: openai.AsyncOpenAI | None = None, model_type: str = settings.OPENAI_PREFERRED_MODEL, - user: str = str(uuid.uuid4()), functions: list = None, verbose=False, ): @@ -472,65 +90,90 @@ def __init__( self.model_type = model_type self.token_limit = get_token_limit(self.model_type) - self.model_limit = self.token_limit * 0.9 + + # Margin reserved for model response + self.model_limit = int(self.token_limit * 0.8) + self.encoder = tiktoken.encoding_for_model(self.model_type) # self.encoder = FakeEncoder() self.chat_id = chat_id self.cache_id = f"grAI:chat_id:{chat_id}" self.system_context = prompt - self.user = user self.api_functions = {func.id: func for func in functions} + self.invalid_api = InvalidAPI(self.api_functions.values()) + self.verbose = verbose - self.prompt_message = self.build_message(SystemMessage, content=self.system_context) + self.summary_prompt: SummaryPrompt = SummaryPrompt(self.encoder) + + # I don't have a clean way to identify the initial token allocation for the tool prompt + # without simply calling the api and seeing what the usage is. + # https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/40 + tool_prompt_token_allocation = 800 + self.baseline_usage: CompletionUsage = CompletionUsage( + completion_tokens=0, prompt_tokens=tool_prompt_token_allocation, total_tokens=tool_prompt_token_allocation + ) + self.prompt_message: UsageMessage = UsageMessage( + message=SystemMessage(content=self.system_context), + usage=self.baseline_usage, + ) if client is None: client = openai.AsyncOpenAI(api_key=settings.OPENAI_API_KEY, organization=settings.OPENAI_ORG_ID) self.client = client - def build_message(self, message_type: Type[T], content: str, **kwargs) -> T: - return message_type(content=content, token_length=len(self.encoder.encode(content)), **kwargs) + def functions(self): + return [func.gpt_definition() for func in self.api_functions.values()] + + def build_message( + self, message_type: Type[T], content: str, current_usage: CompletionUsage, **kwargs + ) -> UsageMessage: + message = message_type(content=content, **kwargs) + encoding = self.encoder.encode(content) + usage = CompletionUsage( + completion_tokens=len(encoding), + prompt_tokens=current_usage.total_tokens, + total_tokens=current_usage.total_tokens + len(encoding), + ) + return UsageMessage(message=message, usage=usage, encoding=encoding) @property - def cached_messages(self) -> list[BaseMessage]: - messages = [ChatMessage(message=message).message for message in cache.get(self.cache_id)] - return messages + def cached_messages(self) -> ChatMessages: + return ChatMessages(messages=cache.get(self.cache_id)) @cached_messages.setter - def cached_messages(self, values: list[BaseMessage]): - cache.set(self.cache_id, [v.dict() for v in values]) + def cached_messages(self, values: ChatMessages): + cache.set(self.cache_id, [v.dict(exclude_none=True) for v in values.messages]) @database_sync_to_async def hydrate_chat(self): + """ + Hydration doesn't currently capture function call context or summarization and will need to be updated to do so. 
+ """ logging.info(f"Hydrating chat history for conversations: {self.chat_id}") messages = cache.get(self.cache_id, None) if messages is None: logging.info(f"Loading chat history for chat {self.chat_id} from database") messages_iter = ( - {"role": m.role, "content": m.message, "token_length": len(self.encoder.encode(m.message))} + {"role": m.role, "content": m.message} for m in Message.objects.filter(chat_id=self.chat_id).order_by("-created_at").all() ) - messages_list = [ChatMessage(message=message).message for message in messages_iter] - self.cached_messages = messages_list - async def summarize(self, messages: list[BaseMessage]) -> AIMessage: - summary_prompt = """ - Please summarize this conversation encoding the most important information a future agent would need to continue - working on the problem with me. Please insure your response is a - text based summary of the conversation to this point with all relevant context for the next agent. - """ - summary_message = SystemMessage(content=summary_prompt) - summary_messages = ChatMessages(messages=[self.prompt_message, *messages, summary_message]) - logging.info(f"Summarizing conversation for chat: {self.chat_id}") - response = await self.client.chat.completions.create(model=self.model_type, messages=summary_messages.to_gpt()) - - # this is hacky for now - summary_message = self.build_message(AIMessage, response.choices[0].message.content) - return summary_message - - def functions(self): - return [func.gpt_definition() for func in self.api_functions.values()] + # default tokens should be computed based on the initial prompts this might be an issue in the future + usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0) + chat_messages = ChatMessages(messages=[]) + for item in messages_iter: + encoding = self.encoder.encode(item["content"]) + usage = CompletionUsage( + completion_tokens=len(encoding), + prompt_tokens=usage.total_tokens, + total_tokens=usage.total_tokens + len(encoding), + ) + usage_message = UsageMessage(message=item, usage=usage, encoding=encoding) + chat_messages.append(usage_message) + + self.cached_messages = chat_messages @property def model(self) -> R: @@ -538,9 +181,12 @@ def model(self) -> R: if len(functions := self.functions()) > 0: base_kwargs |= {"tools": functions, "tool_choice": "auto"} - def inner(**kwargs): - messages = kwargs.pop("messages", []) - messages = [self.prompt_message.representation(), *messages] + def inner(messages: list = None, **kwargs) -> Coroutine[Any, Any, ChatCompletion]: + if messages is None: + messages = [self.prompt_message.message.representation()] + else: + messages = [self.prompt_message.message.representation(), *messages] + return self.client.chat.completions.create( messages=messages, **base_kwargs, @@ -549,93 +195,183 @@ def inner(**kwargs): return inner + async def summarize(self, messages: list[UsageMessage]) -> UsageMessage: + summary_messages = [self.prompt_message, *messages, self.summary_prompt.prompt] + + logging.info(f"Summarizing conversation for chat: {self.chat_id}") + response = await self.client.chat.completions.create(model=self.model_type, messages=to_gpt(summary_messages)) + + summary_usage = CompletionUsage( + completion_tokens=response.usage.completion_tokens, + prompt_tokens=self.prompt_message.usage.total_tokens, + total_tokens=response.usage.total_tokens + self.prompt_message.usage.total_tokens, + ) + + return UsageMessage( + usage=summary_usage, + message=SystemMessage(content=response.choices[0].message.content), + 
encoding=self.encoder.encode(response.choices[0].message.content), + ) + async def evaluate_summary(self, messages: ChatMessages) -> ChatMessages: while True: - prev_accumulated_tokens = self.prompt_message.token_length - accumulated_tokens = self.prompt_message.token_length - i = 0 - for i, message in enumerate(messages.messages): - if hasattr(message, "token_length"): - accumulated_tokens += message.token_length - elif hasattr(message, "content") and message.content is not None: - accumulated_tokens += len(self.encoder.encode(message.content)) - - if accumulated_tokens > self.model_limit: - break - else: - prev_accumulated_tokens = accumulated_tokens + i = messages.index_over_token_limit(self.model_limit) + + if i == 0: + raise Exception("Initial prompt is too long to summarize") - available_tokens = self.model_limit - prev_accumulated_tokens + accumulated_tokens = messages[i].usage.total_tokens + available_tokens = self.model_limit - messages[i - 1].usage.total_tokens + assert available_tokens > 0 message = messages.messages[i] - if i == len(messages) and accumulated_tokens < self.model_limit: + if (i + 1) == len(messages) and accumulated_tokens < self.model_limit: + # Summary is no longer required break - elif available_tokens >= message.token_length: + elif available_tokens >= message.usage.completion_tokens: + # Summary is required but the next message fits in the context window summary = await self.summarize(messages.messages[: (i + 1)]) - messages = [summary, *messages.messages[(i + 1) :]] + messages = ChatMessages(messages=[summary, *messages.messages[(i + 1) :]]) + messages.recompute_usage() + elif message.message.content is None: + # If the message is missing content we can't summarize it. Attempt to summarize the previous + # message and continue + summary = await self.summarize(messages.messages[:i]) + messages = ChatMessages(messages=[summary, *messages.messages[i:]]) + messages.recompute_usage() + elif message.encoding is None: + # Message is missing an encoding. 
This can happen if the message was created by the agent + message.encoding = self.encoder.encode(message.message.content) else: - encoding = self.encoder.encode(message.content) - + # Summary is required but the next message is too large to fit in the context window message_obj = copy.copy(message) next_message_obj = copy.copy(message) - message_obj.content = self.encoder.decode(encoding[:available_tokens]) - message_obj.token_length = len(encoding[:available_tokens]) - next_message_obj.content = self.encoder.decode(encoding[available_tokens:]) - next_message_obj.token_length = len(encoding[available_tokens:]) + message_obj.message.content = self.encoder.decode(message.encoding[:available_tokens]) + message_obj.encoding = message.encoding[:available_tokens] + message_obj.usage.completion_tokens = available_tokens - summary = await self.summarize([*messages.messages[:i], message_obj]) - messages = [summary, next_message_obj, *messages.messages[(i + 1) :]] + next_message_obj.message.content = self.encoder.decode(message.encoding[available_tokens:]) + next_message_obj.encoding = message.encoding[available_tokens:] + next_message_obj.usage.completion_tokens = len(message.encoding[available_tokens:]) - messages = ChatMessages(messages=messages) + summary = await self.summarize([*messages.messages[:i], message_obj]) + if isinstance(message, FunctionMessage): + messages = ChatMessages( + messages=[summary, messages[i - 1], next_message_obj, *messages.messages[(i + 1) :]] + ) + else: + messages = ChatMessages(messages=[summary, next_message_obj, *messages.messages[(i + 1) :]]) + messages.recompute_usage() return messages async def request(self, user_input: str) -> str: - logging.info(f"Responding to request for: {self.chat_id}") - - user_message = self.build_message(UserMessage, content=user_input) - messages = ChatMessages(messages=[*self.cached_messages, user_message]) - - final_response: str | None = None - usage = 0 - - while not final_response: - new_messages = [] - if usage > self.token_limit: - messages = await self.evaluate_summary(messages) + messages = self.cached_messages + messages.append(self.build_message(UserMessage, content=user_input, current_usage=messages[-1].usage)) + while True: response = await self.model(messages=messages.to_gpt()) - - usage = response.usage.total_tokens response_choice = response.choices[0] - response_message = response_choice.message - messages.append(response_message) + messages.append(UsageMessage(usage=response.usage, message=response_choice.message)) + if messages.current_usage.total_tokens > self.model_limit: + messages = await self.evaluate_summary(messages) if finish_reason := response_choice.finish_reason == "stop": - final_response = response_message.content + break elif finish_reason == "length": - summary = await self.summarize(messages[:-1]) - messages = ChatMessages(messages=[summary, messages[-1]]) - elif response_message.tool_calls: - new_messages = [] - for tool_call in response_message.tool_calls: + messages = await self.evaluate_summary(messages) + elif tool_calls := response_choice.message.tool_calls: + tool_request = messages[-1] + tool_request_idx = len(messages) - 1 + tool_responses = ChatMessages(messages=[]) + + # Clear some space for the tool response + if messages.current_usage.total_tokens > (self.model_limit * 0.6): + summary = await self.summarize(messages.messages[:-1]) + messages = ChatMessages(messages=[summary, tool_request]) + + message_tokens = messages.current_usage.total_tokens + for i, tool_call in 
enumerate(tool_calls): func_id = tool_call.function.name func_kwargs = json.loads(tool_call.function.arguments) - api = self.api_functions.get(func_id, InvalidAPI(self.api_functions.values())) + api = self.api_functions.get(func_id, self.invalid_api) response = await api.response(**func_kwargs) + message = self.build_message( + FunctionMessage, + content=response, + name=func_id, + tool_call_id=tool_call.id, + current_usage=messages[-1].usage, + args=func_kwargs, + ) + tool_responses.append(message) + + if message_tokens + tool_responses.current_usage.total_tokens > self.model_limit: + for response in tool_responses.messages: + tool_responses[-1] = await self.summarize_tool(messages[:-1], response) + + messages.extend(tool_responses.messages) + idx = messages.index_over_token_limit(self.model_limit) + if idx != (len(messages) - 1): + response_choice.message.tool_calls = response_choice.message.tool_calls[:idx] + messages[tool_request_idx] = response_choice.message + messages = ChatMessages(messages=messages.messages[:idx]) + + self.cached_messages = messages + return response_choice.message.content + + async def summarize_tool(self, messages: ChatMessages | list[UsageMessage], message: UsageMessage): + def create_question(chunk: str): + return f""" + Please summarize the following chunk of content. + The chunk is: + {chunk} + """ + + if isinstance(messages, list): + messages = ChatMessages(messages=messages) + + assert isinstance(message.message, FunctionMessage) + prompt = f""" + You've requested a tool to help you with your problem, however the response from the tool was too long + to fit in the context window. The tool response requested was {message.message.name} with arguments + {message.message.args}. Please provide a brief description of the details you're looking for which a future + agent will use to summarize the tool response. Ensure you do not actually call any tools in your response. 
+        """
+        chunk_question = self.build_message(SystemMessage, prompt, messages.current_usage)
+        chunk_messages = ChatMessages(messages=[*messages.messages, chunk_question])
+        response = await self.model(messages=chunk_messages.to_gpt())
+        while response.choices[0].message.tool_calls:
+            reiteration = self.build_message(
+                SystemMessage, "YOU MUST NOT CALL TOOLS OR FUNCTIONS", chunk_messages.current_usage
+            )
+            chunk_messages.append(reiteration)
+            response = await self.model(messages=chunk_messages.to_gpt())
+
+        summary_context = response.choices[0].message
-                    if isinstance(api, InvalidAPI):
-                        message = self.build_message(SystemMessage, response)
-                        new_messages.append(message)
-                    else:
-                        message = self.build_message(
-                            FunctionMessage, content=response, name=func_id, tool_call_id=tool_call.id
-                        )
-                        new_messages.append(message)
+        chunk_size = self.model_limit - response.usage.completion_tokens - 100
+        content_chunk_iter = (self.encoder.decode(chunk) for chunk in chunker(message.encoding, chunk_size))
+        content_message_iter = (create_question(content) for content in content_chunk_iter)
+        system_message_iter = (SystemMessage(content=content) for content in content_message_iter)
-                usage += sum([m.token_length for m in new_messages])
-                messages.extend(new_messages)
-        self.cached_messages = messages.messages
-        return final_response
+        callbacks = (
+            self.client.chat.completions.create(model=self.model_type, messages=to_gpt([summary_context, message]))
+            for message in system_message_iter
+        )
+        responses = [r.choices[0].message for r in await asyncio.gather(*[callback for callback in callbacks])]
+        new_message = await self.client.chat.completions.create(
+            model=self.model_type, messages=to_gpt([summary_context, *responses])
+        )
+
+        message = self.build_message(
+            FunctionMessage,
+            content=new_message.choices[0].message.content,
+            name=message.message.name,
+            tool_call_id=message.message.tool_call_id,
+            current_usage=messages.current_usage,
+            args=message.message.args,
+        )
+        return message
 
 
 async def get_chat_conversation(
@@ -643,9 +379,12 @@ async def get_chat_conversation(
 ):
     chat_prompt = """
     You are a helpful assistant with domain expertise about an organizations data and data infrastructure.
+    All of that context is available in a queryable graph where nodes represent individual data concepts like a database column or table.
+    Edges in the graph represent relationships between data such as where the data was sourced from.
     Before you can help the user, you need to understand the context of their request and what they are trying to accomplish.
     You should attempt to understand the context of the request and what the user is trying to accomplish.
     If a user asks about specific data like nodes and you're unable to find an answer you should attempt to find a similar node and explain why you think it's similar.
+    You should verify any specific data references you provide to the user actually exist in their infrastructure.
     Your responses must use Markdown syntax.
Data Structure Notes: @@ -663,10 +402,12 @@ async def get_chat_conversation( # Todo: edge lookup is broken # EdgeLookupAPI(workspace=workspace), # FuzzyMatchNodesAPI(workspace=workspace), - EmbeddingSearchAPI(workspace=workspace, client=client), + EmbeddingSearchAPI(workspace=workspace), NHopQueryAPI(workspace=workspace), ] + # Use Embedding when enabled on workspace otherwise fuzzymatch + conversation = BaseConversation( prompt=chat_prompt, model_type=model_type, functions=functions, chat_id=str(chat_id), client=client ) diff --git a/grai-server/app/grAI/chat_types.py b/grai-server/app/grAI/chat_types.py new file mode 100644 index 000000000..3f9240624 --- /dev/null +++ b/grai-server/app/grAI/chat_types.py @@ -0,0 +1,121 @@ +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.completion_usage import CompletionUsage +from pydantic import BaseModel +from typing import Any, Literal, Union +from multimethod import multimethod + +RoleType = Union[Literal["user"], Literal["system"], Literal["assistant"]] + + +class BaseMessage(BaseModel): + role: str + content: str + + def representation(self) -> dict: + return {"role": self.role, "content": self.content} + + +class UserMessage(BaseMessage): + role: Literal["user"] = "user" + + +class SystemMessage(BaseMessage): + role: Literal["system"] = "system" + + +class AIMessage(BaseMessage): + role: Literal["assistant"] = "assistant" + + +class FunctionMessage(BaseMessage): + role: Literal["tool"] = "tool" + name: str + tool_call_id: str + args: dict + + def representation(self) -> dict: + return {"tool_call_id": self.tool_call_id, "role": self.role, "content": self.content, "name": self.name} + + +UsageMessageTypes = Union[UserMessage, SystemMessage, AIMessage, FunctionMessage, ChatCompletionMessage] + + +class ChatMessage(BaseModel): + message: UsageMessageTypes + + +class UsageMessage(BaseModel): + usage: CompletionUsage + message: UsageMessageTypes + encoding: list | None = None + + +@multimethod +def to_gpt(message: Any) -> dict | ChatCompletionMessage: + raise Exception(f"Cannot convert {type(message)} to GPT format") + + +@to_gpt.register +def list_to_gpt(messages: list) -> list[dict]: + return [to_gpt(message) for message in messages] + + +@to_gpt.register +def usage_message_to_gpt(message: UsageMessage) -> dict: + return to_gpt(message.message) + + +@to_gpt.register +def dict_to_gpt(message: dict) -> dict: + return message + + +@to_gpt.register +def base_message_to_gpt(message: BaseMessage) -> dict: + return message.representation() + + +@to_gpt.register +def chat_completion_message_to_gpt(message: ChatCompletionMessage) -> ChatCompletionMessage: + return message + + +class ChatMessages(BaseModel): + messages: list[UsageMessage] + + def to_gpt(self) -> list[dict]: + return to_gpt(self.messages) + + def __getitem__(self, index): + return self.messages[index] + + def __len__(self) -> int: + return len(self.messages) + + def __setitem__(self, key, value): + self.messages[key] = value + self.recompute_usage(key) + + def append(self, item): + self.messages.append(item) + self.recompute_usage(len(self.messages) - 1) + + def extend(self, items): + self.messages.extend(items) + self.recompute_usage(len(self.messages) - len(items)) + + @property + def current_usage(self) -> CompletionUsage: + return self.messages[-1].usage + + def recompute_usage(self, from_index: int = 0): + usage = self.messages[from_index].usage + for message in self.messages[from_index + 1 :]: + message.usage.prompt_tokens = 
usage.total_tokens
+            message.usage.total_tokens = usage.total_tokens + message.usage.completion_tokens
+            usage = message.usage
+
+    def index_over_token_limit(self, token_limit) -> int:
+        for index, message in enumerate(self.messages):
+            if message.usage.total_tokens > token_limit:
+                return index
+        return index
diff --git a/grai-server/app/grAI/encoders.py b/grai-server/app/grAI/encoders.py
index d9be5a224..9bdcee251 100644
--- a/grai-server/app/grAI/encoders.py
+++ b/grai-server/app/grAI/encoders.py
@@ -1,17 +1,22 @@
 import tiktoken
 import openai
 from typing import TypeVar
+from django.conf import settings
+
 R = TypeVar("R")
 
 
 class OpenAIEmbedder:
-    def __init__(self, model: str, context_window: int):
+    def __init__(self, model: str, context_window: int, client: openai.AsyncOpenAI | None = None):
         self.model = model
         self.model_context_window = context_window
         self.encoder = tiktoken.encoding_for_model(self.model)
+        if client is None:
+            client = openai.AsyncOpenAI(api_key=settings.OPENAI_API_KEY, organization=settings.OPENAI_ORG_ID)
+        self.client: openai.AsyncOpenAI = client
 
-        self.heuristic_max_length = self.model_context_window * 4 * 0.85
+        self.heuristic_max_length = int(self.model_context_window * 4 * 0.85)
 
     def get_encoding(self, content: str) -> list[int]:
         return self.encoder.encode(content)
@@ -20,8 +25,6 @@ def decode(self, encoding: list[int]) -> str:
         return self.encoder.decode(encoding)
 
     def get_max_length_content(self, content: str) -> str:
-        content_length = len(content)
-
         # Heuristic estimate of the max length of content that can be encoded
         if len(content) < self.heuristic_max_length:
             return content
@@ -32,6 +35,9 @@ def get_max_length_content(self, content: str) -> str:
         else:
             return self.decode(encoded[: self.model_context_window])
 
-    def get_embedding(self, content: str) -> R:
+    async def get_embedding(self, content: str) -> R:
         content = self.get_max_length_content(content)
-        return openai.embedding.create(input=content, model=self.model)
+        return await self.client.embeddings.create(input=content, model=self.model)
+
+
+Embedder = OpenAIEmbedder("text-embedding-ada-002", 8100)
diff --git a/grai-server/app/grAI/mocks.py b/grai-server/app/grAI/mocks.py
new file mode 100644
index 000000000..2ed2f42e4
--- /dev/null
+++ b/grai-server/app/grAI/mocks.py
@@ -0,0 +1,3 @@
+class FakeEncoder:
+    def encode(self, text):
+        return [1, 2, 3, 4]
diff --git a/grai-server/app/grAI/tools.py b/grai-server/app/grAI/tools.py
new file mode 100644
index 000000000..cea97dc8e
--- /dev/null
+++ b/grai-server/app/grAI/tools.py
@@ -0,0 +1,340 @@
+import operator
+import uuid
+from abc import ABC, abstractmethod
+from functools import reduce
+import logging
+
+from lineage.models import Edge, Node, NodeEmbeddings
+from pydantic import BaseModel, Field
+from grai_schemas.serializers import GraiYamlSerializer
+from django.db.models import Q
+from channels.db import database_sync_to_async
+from pgvector.django import MaxInnerProduct
+
+from typing import Annotated, Any, Callable, Literal, ParamSpec, Type, TypeVar, Union
+from connections.adapters.schemas import model_to_schema
+import tiktoken
+import openai
+from grai_schemas.v1.edge import EdgeV1
+from grAI.encoders import Embedder
+
+T = TypeVar("T")
+R = TypeVar("R")
+P = ParamSpec("P")
+
+RoleType = Union[Literal["user"], Literal["system"], Literal["assistant"]]
+
+
+def filter_node_content(node: "Node") -> dict:
+    from connections.adapters.schemas import model_to_schema
+
+    spec_keys = ["name", "namespace", "metadata", "data_sources"]
+
+    result: dict = model_to_schema(node,
"NodeV1").spec.dict() + result = {key: result[key] for key in spec_keys} + result["metadata"] = result["metadata"]["grai"] + return result + + +class API(ABC): + schema_model: BaseModel + description: str + id: str + + @abstractmethod + def call(self, **kwargs) -> (Any, str): + pass + + def serialize(self, result) -> str: + if isinstance(result, str): + return result + + return GraiYamlSerializer.dump(result) + + async def response(self, **kwargs) -> str: + obj, message = await self.call(**kwargs) + + if message is None: + result = self.serialize(obj) + else: + result = f"{self.serialize(obj)}\n{message}" + + return result + + def gpt_definition(self) -> dict: + return { + "type": "function", + "function": {"name": self.id, "description": self.description, "parameters": self.schema_model.schema()}, + } + + +class NodeIdentifier(BaseModel): + name: str = Field(description="The name of the node to query for") + namespace: str = Field(description="The namespace of the node to query for") + + +class NodeLookup(BaseModel): + nodes: list[NodeIdentifier] = Field(description="A list of nodes to lookup") + + +class NodeLookupAPI(API): + id = "node_lookup" + description = "Lookup metadata about one or more nodes if you know precisely which node(s) to lookup" + schema_model = NodeLookup + + def __init__(self, workspace: str | uuid.UUID): + self.workspace = workspace + + @staticmethod + def response_message(result_set: list[Node]) -> str | None: + total_results = len(result_set) + if total_results == 0: + message = "No results found matching these query conditions." + else: + message = None + + return message + + @database_sync_to_async + def call(self, **kwargs) -> (list[Node], str | None): + try: + validation = self.schema_model(**kwargs) + except Exception as e: + return [], f"Invalid input. {e}" + q_objects = (Q(**node.dict(exclude_none=True)) for node in validation.nodes) + query = reduce(operator.or_, q_objects, Q()) + result_set = ( + Node.objects.filter(workspace=self.workspace) + .filter(query) + .prefetch_related("data_sources") + .order_by("-created_at") + .all() + ) + response_items = model_to_schema(result_set, "NodeV1") + return response_items, self.response_message(result_set) + + +class FuzzyMatchQuery(BaseModel): + string: str = Field(description="The fuzzy string used to search amongst node names") + + +class FuzzyMatchNodesAPI(API): + id = "node_fuzzy_lookup" + description = "Performs a fuzzy search for nodes matching a name regardless of namespace" + schema_model = FuzzyMatchQuery + + def __init__(self, workspace: str | uuid.UUID): + self.workspace = workspace + + @staticmethod + def response_message(result_set: list[Node]) -> str | None: + total_results = len(result_set) + if total_results == 0: + message = "No results found matching these query conditions." 
+ else: + message = None + + return message + + @database_sync_to_async + def call(self, string: str) -> (list, str | None): + result_set = ( + Node.objects.prefetch_related("data_sources") + .filter(workspace=self.workspace) + .filter(name__contains=string) + .order_by("-created_at") + .all() + ) + response_items = [{"name": node.name, "namespace": node.namespace} for node in result_set] + + return response_items, self.response_message(result_set) + + +class EdgeLookupSchema(BaseModel): + source: NodeIdentifier = Field(description="The primary key of the source node on an edge") + destination: NodeIdentifier = Field(description="The primary key of the destination node on an edge") + + +class EdgeLookupAPI(API): + id = "edge_lookup" + description = """ + This function Supports looking up edges from a data lineage graph. For example, a query with name=Test but no + namespace value will return all edges explicitly named "Test" regardless of namespace. + Edges are uniquely identified both by their (name, namespace), and by the (source, destination) nodes they connect. + """ + schema_model = EdgeLookupSchema + + def __init__(self, workspace: str | uuid.UUID): + self.workspace = workspace + + @staticmethod + def response_message(result_set: list[Edge]) -> str | None: + total_results = len(result_set) + if total_results == 0: + message = "No results found matching these query conditions." + else: + message = None + + return message + + @database_sync_to_async + def call(self, **kwargs) -> tuple[list[EdgeV1], str | None]: + try: + validation = self.schema_model(**kwargs) + except Exception as e: + return [], f"Invalid input. {e}" + + query = Q() + query &= Q(source__name=validation.source.name, source__namespace=validation.source.namespace) + query &= Q( + destination__name=validation.destination.name, destination__namespace=validation.destination.namespace + ) + + results = Edge.objects.filter(workspace=self.workspace).filter(query).all() + return model_to_schema(results, "EdgeV1"), self.response_message(results) + + +class NHopQuerySchema(BaseModel): + name: str = Field(description="The name of the node to query for") + namespace: str = Field(description="The namespace of the node to query for") + n: int = Field(description="The number of hops to query for", default=1) + + +class NHopQueryAPI(API): + id: str = "n_hop_query" + description: str = "query for nodes and edges within a specified number of hops from a given node" + schema_model = NHopQuerySchema + + def __init__(self, workspace: str | uuid.UUID): + self.workspace = workspace + + @staticmethod + def response_message(result_set: list[str]) -> str | None: + total_results = len(result_set) + if total_results == 0: + message = "No results found matching these query conditions." + else: + message = ( + "Results are returned in the following format:\n" + "(source name, source namespace) - (edge type) -> (destination name, destination namespace)\n" + "(source2 name, source2 namespace) - (edge type) -> (destination2 name, destination2 namespace)\n..." 
+ ) + + return message + + @staticmethod + def filter(queryset: list[Edge], source_nodes: list[Node], dest_nodes: list[Node]) -> tuple[list[Node], list[Node]]: + def get_id(node: Node) -> tuple[str, str]: + return node.name, node.namespace + + source_ids: set[T] = {get_id(node) for node in source_nodes} + dest_ids: set[T] = {get_id(node) for node in dest_nodes} + query_hashes: set[tuple[T, T]] = {(get_id(node.source), get_id(node.destination)) for node in queryset} + + source_resp = [n.destination for n, hashes in zip(queryset, query_hashes) if hashes[0] in source_ids] + dest_resp = [n.source for n, hashes in zip(queryset, query_hashes) if hashes[1] in dest_ids] + return source_resp, dest_resp + + @database_sync_to_async + def call(self, **kwargs) -> (str, str | None): + def edge_label(edge: Edge): + edge_type = edge.metadata["grai"]["edge_type"] + return f"({edge.source.name}, {edge.source.namespace}) - {edge_type} -> ({edge.destination.name}, {edge.destination.namespace})" + + try: + inp = self.schema_model(**kwargs) + except Exception as e: + return [], f"Invalid input: {e}" + + source_node = Node.objects.filter(workspace__id=self.workspace, name=inp.name, namespace=inp.namespace).first() + if source_node is None: + return [], self.response_message([]) + + source_nodes = [source_node] + dest_nodes = [source_node] + + return_edges = [] + for i in range(inp.n): + query = Q(source__in=source_nodes) | Q(destination__in=dest_nodes) + edges = Edge.objects.filter(query).select_related("source", "destination").all() + + source_nodes, dest_nodes = self.filter(edges, source_nodes, dest_nodes) + return_edges.extend((edge_label(item) for item in edges)) + + if len(source_nodes) == 0 and len(dest_nodes) == 0: + break + + return_str = "\n".join(return_edges) + return return_str, self.response_message(return_edges) + + +class EmbeddingSearchSchema(BaseModel): + search_term: str = Field(description="Search request string") + limit: int = Field(description="number of results to return", default=10) + + +class EmbeddingSearchAPI(API): + id: str = "embedding_search_api" + description: str = ( + "Search for nodes matching any query. Results are returned in the following format:\n" + "(node1 name, node1 namespace, node1 type)\n(node2 name, node2 namespace, node2 type)\n..." + ) + schema_model = EmbeddingSearchSchema + + def __init__(self, workspace: str | uuid.UUID): + self.workspace = workspace + + @staticmethod + def response_message(result_set: list) -> str | None: + total_results = len(result_set) + if total_results == 0: + message = "No results found matching these query conditions." + else: + message = None + + return message + + @database_sync_to_async + def nearest_neighbor_search(self, vector_query: list[int], limit=10) -> list[Node]: + node_result = ( + NodeEmbeddings.objects.filter(node__workspace__id=self.workspace) + .order_by(MaxInnerProduct("embedding", vector_query)) + .select_related("node")[:limit] + ) + + return [n.node for n in node_result] + + async def call(self, **kwargs) -> tuple[str, str | None]: + try: + inp = self.schema_model(**kwargs) + except Exception as e: + return f"Invalid input. 
{e}", None + + embedding_resp = await Embedder.get_embedding(inp.search_term) + embedding = list(embedding_resp.data[0].embedding) + + neighbors = await self.nearest_neighbor_search(embedding, inp.limit) + response = ((n.name, n.namespace, n.metadata["grai"]["node_type"]) for n in neighbors) + response_str = "\n".join([", ".join(vals) for vals in response]) + + return response_str, self.response_message(neighbors) + + +class InvalidApiSchema(BaseModel): + pass + + +class InvalidAPI(API): + id = "invalid_api_endpoint" + description = "placeholder" + schema_model = InvalidApiSchema + + def __init__(self, apis): + self.function_string = ", ".join([f"`{api.id}`" for api in apis]) + + @database_sync_to_async + def call(self, *args, **kwargs) -> tuple[str, str | None]: + return ( + f"Invalid API Endpoint. That function does not exist. The supported apis are {self.function_string}", + None, + ) diff --git a/grai-server/app/lineage/tasks.py b/grai-server/app/lineage/tasks.py index 4b20d89f9..c5aba91b0 100644 --- a/grai-server/app/lineage/tasks.py +++ b/grai-server/app/lineage/tasks.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import TYPE_CHECKING, TypeVar from uuid import UUID +import asyncio import openai from django.core.cache import cache @@ -9,15 +10,13 @@ from grai_schemas.serializers import GraiYamlSerializer from celery import shared_task -from grAI.encoders import OpenAIEmbedder +from grAI.encoders import Embedder + T = TypeVar("T") R = TypeVar("R") -Embedder = OpenAIEmbedder("text-embedding-ada-002", 8100) - - if TYPE_CHECKING: from lineage.models import Node @@ -42,7 +41,7 @@ def create_node_vector_index(node: "Node"): from lineage.models import NodeEmbeddings content = get_embedded_node_content(node) - embedding_resp = Embedder.get_embedding(content) + embedding_resp = asyncio.run(Embedder.get_embedding(content)) NodeEmbeddings.objects.update_or_create(node=node, embedding=embedding_resp.data[0].embedding) @@ -69,7 +68,7 @@ def update_node_vector_index(self, node_id: UUID, task_id: UUID | None = None): node = Node.objects.prefetch_related("data_sources").get(id=node_id) try: create_node_vector_index(node) - except openai.error.RateLimitError: + except openai.RateLimitError: logging.info(f"Openai rate limit reach retrying in 10 seconds") self.retry(countdown=10) return
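
A minimal sketch (not part of the diff above) of how the new token accounting in grAI/chat_types.py is intended to compose: each UsageMessage carries a CompletionUsage whose total_tokens is a running sum over the conversation, which is what ChatMessages.index_over_token_limit and BaseConversation.evaluate_summary rely on. The model name, token limit, and message contents below are illustrative assumptions, not values taken from the diff.

import tiktoken
from openai.types.completion_usage import CompletionUsage

from grAI.chat_types import ChatMessages, UsageMessage, UserMessage

# Illustrative assumptions: model name and token limit are placeholders.
encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
token_limit = 4096

messages = ChatMessages(messages=[])
running = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)

for content in ["You are a helpful assistant.", "Which tables feed the orders dashboard?"]:
    encoding = encoder.encode(content)
    # prompt_tokens = cumulative tokens before this message; total_tokens = cumulative including it,
    # mirroring BaseConversation.build_message and hydrate_chat in the diff above.
    running = CompletionUsage(
        completion_tokens=len(encoding),
        prompt_tokens=running.total_tokens,
        total_tokens=running.total_tokens + len(encoding),
    )
    messages.append(UsageMessage(message=UserMessage(content=content), usage=running, encoding=encoding))

# current_usage is the last (cumulative) usage; index_over_token_limit returns the first index whose
# cumulative total exceeds the limit, or the last index when nothing does.
print(messages.current_usage.total_tokens, messages.index_over_token_limit(token_limit))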