Skip to content

Commit

Permalink
Chat suggestions (#790)
Browse files Browse the repository at this point in the history
* Add welcome message to new chats

* Fix for agent const issue

* Better websocket on message handling

* Add choices
  • Loading branch information
edlouth authored Nov 9, 2023
1 parent 598f35e commit 01146a1
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 20 deletions.
27 changes: 27 additions & 0 deletions grai-frontend/src/components/chat/ChatChoices.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import React from "react"
import { Button, Grid } from "@mui/material"

type ChatChoicesProps = {
choices: string[]
onInput: (message: string) => void
}

const ChatChoices: React.FC<ChatChoicesProps> = ({ choices, onInput }) => (
<Grid container spacing={2} sx={{ mb: 2 }}>
<Grid item xs={1} />
<Grid item xs={10}>
{choices.map(choice => (
<Button
key={choice}
variant="contained"
onClick={() => onInput(choice)}
sx={{ mb: 2, borderRadius: 4 }}
>
{choice}
</Button>
))}
</Grid>
</Grid>
)

export default ChatChoices
6 changes: 4 additions & 2 deletions grai-frontend/src/components/chat/ChatHistory.test.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { render, screen } from "testing"
import ChatHistory from "./ChatHistory"

const onInput = jest.fn()

test("renders", async () => {
const messages = [
{
Expand All @@ -21,12 +23,12 @@ test("renders", async () => {
},
]

render(<ChatHistory messages={messages} />)
render(<ChatHistory messages={messages} choices={[]} onInput={onInput} />)

expect(screen.getByText("first message")).toBeInTheDocument()
expect(screen.getByText("second message")).toBeInTheDocument()
})

test("renders empty", async () => {
render(<ChatHistory messages={[]} />)
render(<ChatHistory messages={[]} choices={[]} onInput={onInput} />)
})
10 changes: 9 additions & 1 deletion grai-frontend/src/components/chat/ChatHistory.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import React from "react"
import { Box } from "@mui/material"
import ChatChoices from "./ChatChoices"
import ChatMessage, { GroupedChats } from "./ChatMessage"
import { Message } from "./ChatWindow"

Expand All @@ -16,13 +17,20 @@ const combineMessages = (messages: Message[]) =>

type ChatHistoryProps = {
messages: Message[]
choices: string[]
onInput: (message: string) => void
}

const ChatHistory: React.FC<ChatHistoryProps> = ({ messages }) => (
const ChatHistory: React.FC<ChatHistoryProps> = ({
messages,
choices,
onInput,
}) => (
<Box sx={{ flexGrow: 1, overflow: "auto", height: "200px" }}>
{combineMessages(messages).map((groupedChat, i) => (
<ChatMessage key={i} groupedChat={groupedChat} />
))}
<ChatChoices choices={choices} onInput={onInput} />
</Box>
)

Expand Down
1 change: 0 additions & 1 deletion grai-frontend/src/components/chat/ChatMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ const ChatMessage: React.FC<ChatMessageProps> = ({ groupedChat }) => {
}}
>
<Typography
align={"left"}
sx={{
px: 2,
py: 1,
Expand Down
18 changes: 16 additions & 2 deletions grai-frontend/src/components/chat/ChatWindow.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,29 @@ import ChatWindow from "./ChatWindow"
const handleInput = jest.fn()

test("renders", async () => {
render(<ChatWindow messages={[]} onInput={handleInput} workspaceId="1" />)
render(
<ChatWindow
messages={[]}
onInput={handleInput}
workspaceId="1"
choices={[]}
/>,
)

expect(screen.getByRole("textbox")).toBeInTheDocument()
})

test("type", async () => {
const user = userEvent.setup()

render(<ChatWindow messages={[]} onInput={handleInput} workspaceId="1" />)
render(
<ChatWindow
messages={[]}
onInput={handleInput}
workspaceId="1"
choices={[]}
/>,
)

expect(screen.getByRole("textbox")).toBeInTheDocument()

Expand Down
4 changes: 3 additions & 1 deletion grai-frontend/src/components/chat/ChatWindow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@ export type Message = {

type ChatWindowProps = {
messages: Message[]
choices: string[]
onInput: (message: string) => void
workspaceId: string
}

const ChatWindow: React.FC<ChatWindowProps> = ({
messages,
choices,
onInput,
workspaceId,
}) => (
<Box sx={{ display: "flex", flexDirection: "column", height: "100%" }}>
<ChatHistory messages={messages} />
<ChatHistory messages={messages} choices={choices} onInput={onInput} />
<ResetChat workspaceId={workspaceId} />
<ChatInput onInput={onInput} />
</Box>
Expand Down
59 changes: 59 additions & 0 deletions grai-frontend/src/components/chat/WebsocketChat.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,65 @@ afterEach(() => {
WS.clean()
})

test("renders", async () => {
const chat = {
id: "1",
messages: {
data: [
{
id: "1",
message: "H",
role: "user",
created_at: "2021-04-20T00:00:00.000000Z",
},
],
},
}

render(<WebsocketChat workspace={workspace} chat={chat} />)

expect(
screen.queryByRole("button", {
name: "Is there a customer table in the prod namespace?",
}),
).not.toBeInTheDocument()
})

test("click choice", async () => {
const user = userEvent.setup()

const chat = {
id: "1",
messages: {
data: [
{
id: "1",
message: "H",
role: "system",
created_at: "2021-04-20T00:00:00.000000Z",
},
],
},
}

render(<WebsocketChat workspace={workspace} chat={chat} />)

expect(
screen.getByRole("button", {
name: "Is there a customer table in the prod namespace?",
}),
).toBeInTheDocument()

await act(
async () =>
await user.click(
screen.getByRole("button", {
name: "Is there a customer table in the prod namespace?",
}),
),
)
})

test("type", async () => {
const user = userEvent.setup()

Expand Down
29 changes: 18 additions & 11 deletions grai-frontend/src/components/chat/WebsocketChat.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useCallback, useEffect } from "react"
import React, { useCallback } from "react"
import { gql, useApolloClient } from "@apollo/client"
import { baseURL } from "client"
import useWebSocket from "react-use-websocket"
Expand Down Expand Up @@ -41,7 +41,14 @@ const WebsocketChat: React.FC<WebsocketChatProps> = ({ workspace, chat }) => {

const socketUrl = `${socketURL}/ws/chat/${workspace.id}/`

const { sendJsonMessage, lastMessage } = useWebSocket(socketUrl)
const { sendJsonMessage } = useWebSocket(socketUrl, {
onMessage: event => {
addMessage({
message: JSON.parse(event.data).message,
role: "system",
})
},
})

const addMessage = useCallback(
(message: Message) =>
Expand Down Expand Up @@ -76,15 +83,6 @@ const WebsocketChat: React.FC<WebsocketChatProps> = ({ workspace, chat }) => {
[cache, chat.id],
)

useEffect(() => {
if (lastMessage) {
addMessage({
message: JSON.parse(lastMessage.data).message,
role: "system",
})
}
}, [lastMessage, addMessage])

const handleInput = (message: string) => {
const msg = {
type: "chat.message",
Expand All @@ -100,9 +98,18 @@ const WebsocketChat: React.FC<WebsocketChatProps> = ({ workspace, chat }) => {
sender: message.role === "user",
}))

const choices = messages.every(m => !m.sender)
? [
"Is there a customer table in the prod namespace?",
"Which tables have an email field?",
"Produce a list of tables that aren't used by any application",
]
: []

return (
<ChatWindow
messages={messages}
choices={choices}
onInput={handleInput}
workspaceId={workspace.id}
/>
Expand Down
9 changes: 9 additions & 0 deletions grai-server/app/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from connections.models import Connection as ConnectionModel
from connections.models import Run as RunModel
from connections.types import Connector, ConnectorFilter
from grAI.models import Message as MessageModel
from grAI.models import MessageRoles
from grAI.models import UserChat as ChatModel
from grAI.types import Chat
from installations.models import Branch as BranchModel
Expand Down Expand Up @@ -1076,6 +1078,13 @@ async def last_chat(
if not chat:
chat = await ChatModel.objects.acreate(membership=membership)

await MessageModel.objects.acreate(
chat=chat,
message="Hello, I'm the GrAI assistant. How can I help you?",
visible=True,
role=MessageRoles.AGENT.value,
)

return chat


Expand Down
17 changes: 15 additions & 2 deletions grai-server/app/grAI/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from api.common import IsAuthenticated, get_user, get_workspace
from api.pagination import DataWrapper

from .models import UserChat
from .models import Message, MessageRoles, UserChat
from .types import Chat


Expand All @@ -21,12 +21,25 @@ def _create(
info: Info,
workspaceId: strawberry.ID,
) -> DataWrapper[Chat]:
def add_message(message: str):
Message.objects.create(
chat=chat,
message=message,
visible=True,
role=MessageRoles.AGENT.value,
)

user = get_user(info)
workspace = get_workspace(info, workspaceId)

membership = workspace.memberships.get(user=user)

return UserChat.objects.create(membership=membership)
chat = UserChat.objects.create(membership=membership)

add_message("Chat restarted")
add_message("Hello, I'm the GrAI assistant. How can I help you?")

return chat

return await sync_to_async(_create)(
info,
Expand Down

0 comments on commit 01146a1

Please sign in to comment.