diff --git a/grai-frontend/src/components/chat/ChatChoices.tsx b/grai-frontend/src/components/chat/ChatChoices.tsx new file mode 100644 index 000000000..e64ed0330 --- /dev/null +++ b/grai-frontend/src/components/chat/ChatChoices.tsx @@ -0,0 +1,27 @@ +import React from "react" +import { Button, Grid } from "@mui/material" + +type ChatChoicesProps = { + choices: string[] + onInput: (message: string) => void +} + +const ChatChoices: React.FC = ({ choices, onInput }) => ( + + + + {choices.map(choice => ( + + ))} + + +) + +export default ChatChoices diff --git a/grai-frontend/src/components/chat/ChatHistory.test.tsx b/grai-frontend/src/components/chat/ChatHistory.test.tsx index 6aca0661a..197c495a2 100644 --- a/grai-frontend/src/components/chat/ChatHistory.test.tsx +++ b/grai-frontend/src/components/chat/ChatHistory.test.tsx @@ -1,6 +1,8 @@ import { render, screen } from "testing" import ChatHistory from "./ChatHistory" +const onInput = jest.fn() + test("renders", async () => { const messages = [ { @@ -21,12 +23,12 @@ test("renders", async () => { }, ] - render() + render() expect(screen.getByText("first message")).toBeInTheDocument() expect(screen.getByText("second message")).toBeInTheDocument() }) test("renders empty", async () => { - render() + render() }) diff --git a/grai-frontend/src/components/chat/ChatHistory.tsx b/grai-frontend/src/components/chat/ChatHistory.tsx index c3b53e4ff..1b5c3ae08 100644 --- a/grai-frontend/src/components/chat/ChatHistory.tsx +++ b/grai-frontend/src/components/chat/ChatHistory.tsx @@ -1,5 +1,6 @@ import React from "react" import { Box } from "@mui/material" +import ChatChoices from "./ChatChoices" import ChatMessage, { GroupedChats } from "./ChatMessage" import { Message } from "./ChatWindow" @@ -16,13 +17,20 @@ const combineMessages = (messages: Message[]) => type ChatHistoryProps = { messages: Message[] + choices: string[] + onInput: (message: string) => void } -const ChatHistory: React.FC = ({ messages }) => ( +const ChatHistory: React.FC = ({ + messages, + choices, + onInput, +}) => ( {combineMessages(messages).map((groupedChat, i) => ( ))} + ) diff --git a/grai-frontend/src/components/chat/ChatMessage.tsx b/grai-frontend/src/components/chat/ChatMessage.tsx index 5a87c9e56..1793c617a 100644 --- a/grai-frontend/src/components/chat/ChatMessage.tsx +++ b/grai-frontend/src/components/chat/ChatMessage.tsx @@ -88,7 +88,6 @@ const ChatMessage: React.FC = ({ groupedChat }) => { }} > { - render() + render( + , + ) expect(screen.getByRole("textbox")).toBeInTheDocument() }) @@ -13,7 +20,14 @@ test("renders", async () => { test("type", async () => { const user = userEvent.setup() - render() + render( + , + ) expect(screen.getByRole("textbox")).toBeInTheDocument() diff --git a/grai-frontend/src/components/chat/ChatWindow.tsx b/grai-frontend/src/components/chat/ChatWindow.tsx index c66731f43..f1708c202 100644 --- a/grai-frontend/src/components/chat/ChatWindow.tsx +++ b/grai-frontend/src/components/chat/ChatWindow.tsx @@ -11,17 +11,19 @@ export type Message = { type ChatWindowProps = { messages: Message[] + choices: string[] onInput: (message: string) => void workspaceId: string } const ChatWindow: React.FC = ({ messages, + choices, onInput, workspaceId, }) => ( - + diff --git a/grai-frontend/src/components/chat/WebsocketChat.test.tsx b/grai-frontend/src/components/chat/WebsocketChat.test.tsx index e60d94f76..f37aeac5a 100644 --- a/grai-frontend/src/components/chat/WebsocketChat.test.tsx +++ b/grai-frontend/src/components/chat/WebsocketChat.test.tsx @@ -25,6 +25,65 @@ afterEach(() => { WS.clean() }) +test("renders", async () => { + const chat = { + id: "1", + messages: { + data: [ + { + id: "1", + message: "H", + role: "user", + created_at: "2021-04-20T00:00:00.000000Z", + }, + ], + }, + } + + render() + + expect( + screen.queryByRole("button", { + name: "Is there a customer table in the prod namespace?", + }), + ).not.toBeInTheDocument() +}) + +test("click choice", async () => { + const user = userEvent.setup() + + const chat = { + id: "1", + messages: { + data: [ + { + id: "1", + message: "H", + role: "system", + created_at: "2021-04-20T00:00:00.000000Z", + }, + ], + }, + } + + render() + + expect( + screen.getByRole("button", { + name: "Is there a customer table in the prod namespace?", + }), + ).toBeInTheDocument() + + await act( + async () => + await user.click( + screen.getByRole("button", { + name: "Is there a customer table in the prod namespace?", + }), + ), + ) +}) + test("type", async () => { const user = userEvent.setup() diff --git a/grai-frontend/src/components/chat/WebsocketChat.tsx b/grai-frontend/src/components/chat/WebsocketChat.tsx index fca707537..1fe7e3bf8 100644 --- a/grai-frontend/src/components/chat/WebsocketChat.tsx +++ b/grai-frontend/src/components/chat/WebsocketChat.tsx @@ -1,4 +1,4 @@ -import React, { useCallback, useEffect } from "react" +import React, { useCallback } from "react" import { gql, useApolloClient } from "@apollo/client" import { baseURL } from "client" import useWebSocket from "react-use-websocket" @@ -41,7 +41,14 @@ const WebsocketChat: React.FC = ({ workspace, chat }) => { const socketUrl = `${socketURL}/ws/chat/${workspace.id}/` - const { sendJsonMessage, lastMessage } = useWebSocket(socketUrl) + const { sendJsonMessage } = useWebSocket(socketUrl, { + onMessage: event => { + addMessage({ + message: JSON.parse(event.data).message, + role: "system", + }) + }, + }) const addMessage = useCallback( (message: Message) => @@ -76,15 +83,6 @@ const WebsocketChat: React.FC = ({ workspace, chat }) => { [cache, chat.id], ) - useEffect(() => { - if (lastMessage) { - addMessage({ - message: JSON.parse(lastMessage.data).message, - role: "system", - }) - } - }, [lastMessage, addMessage]) - const handleInput = (message: string) => { const msg = { type: "chat.message", @@ -100,9 +98,18 @@ const WebsocketChat: React.FC = ({ workspace, chat }) => { sender: message.role === "user", })) + const choices = messages.every(m => !m.sender) + ? [ + "Is there a customer table in the prod namespace?", + "Which tables have an email field?", + "Produce a list of tables that aren't used by any application", + ] + : [] + return ( diff --git a/grai-server/app/api/types.py b/grai-server/app/api/types.py index 76dcfc316..d4bad2e5a 100755 --- a/grai-server/app/api/types.py +++ b/grai-server/app/api/types.py @@ -21,6 +21,8 @@ from connections.models import Connection as ConnectionModel from connections.models import Run as RunModel from connections.types import Connector, ConnectorFilter +from grAI.models import Message as MessageModel +from grAI.models import MessageRoles from grAI.models import UserChat as ChatModel from grAI.types import Chat from installations.models import Branch as BranchModel @@ -1076,6 +1078,13 @@ async def last_chat( if not chat: chat = await ChatModel.objects.acreate(membership=membership) + await MessageModel.objects.acreate( + chat=chat, + message="Hello, I'm the GrAI assistant. How can I help you?", + visible=True, + role=MessageRoles.AGENT.value, + ) + return chat diff --git a/grai-server/app/grAI/mutations.py b/grai-server/app/grAI/mutations.py index c07e34ea7..ef3055132 100644 --- a/grai-server/app/grAI/mutations.py +++ b/grai-server/app/grAI/mutations.py @@ -5,7 +5,7 @@ from api.common import IsAuthenticated, get_user, get_workspace from api.pagination import DataWrapper -from .models import UserChat +from .models import Message, MessageRoles, UserChat from .types import Chat @@ -21,12 +21,25 @@ def _create( info: Info, workspaceId: strawberry.ID, ) -> DataWrapper[Chat]: + def add_message(message: str): + Message.objects.create( + chat=chat, + message=message, + visible=True, + role=MessageRoles.AGENT.value, + ) + user = get_user(info) workspace = get_workspace(info, workspaceId) membership = workspace.memberships.get(user=user) - return UserChat.objects.create(membership=membership) + chat = UserChat.objects.create(membership=membership) + + add_message("Chat restarted") + add_message("Hello, I'm the GrAI assistant. How can I help you?") + + return chat return await sync_to_async(_create)( info,