Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chat suggestions #790

Merged
merged 4 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions grai-frontend/src/components/chat/ChatChoices.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import React from "react"
import { Button, Grid } from "@mui/material"

type ChatChoicesProps = {
choices: string[]
onInput: (message: string) => void
}

const ChatChoices: React.FC<ChatChoicesProps> = ({ choices, onInput }) => (
<Grid container spacing={2} sx={{ mb: 2 }}>
<Grid item xs={1} />
<Grid item xs={10}>
{choices.map(choice => (
<Button
key={choice}
variant="contained"
onClick={() => onInput(choice)}
sx={{ mb: 2, borderRadius: 4 }}
>
{choice}
</Button>
))}
</Grid>
</Grid>
)

export default ChatChoices
6 changes: 4 additions & 2 deletions grai-frontend/src/components/chat/ChatHistory.test.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { render, screen } from "testing"
import ChatHistory from "./ChatHistory"

const onInput = jest.fn()

test("renders", async () => {
const messages = [
{
Expand All @@ -21,12 +23,12 @@ test("renders", async () => {
},
]

render(<ChatHistory messages={messages} />)
render(<ChatHistory messages={messages} choices={[]} onInput={onInput} />)

expect(screen.getByText("first message")).toBeInTheDocument()
expect(screen.getByText("second message")).toBeInTheDocument()
})

test("renders empty", async () => {
render(<ChatHistory messages={[]} />)
render(<ChatHistory messages={[]} choices={[]} onInput={onInput} />)
})
10 changes: 9 additions & 1 deletion grai-frontend/src/components/chat/ChatHistory.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import React from "react"
import { Box } from "@mui/material"
import ChatChoices from "./ChatChoices"
import ChatMessage, { GroupedChats } from "./ChatMessage"
import { Message } from "./ChatWindow"

Expand All @@ -16,13 +17,20 @@ const combineMessages = (messages: Message[]) =>

type ChatHistoryProps = {
messages: Message[]
choices: string[]
onInput: (message: string) => void
}

const ChatHistory: React.FC<ChatHistoryProps> = ({ messages }) => (
const ChatHistory: React.FC<ChatHistoryProps> = ({
messages,
choices,
onInput,
}) => (
<Box sx={{ flexGrow: 1, overflow: "auto", height: "200px" }}>
{combineMessages(messages).map((groupedChat, i) => (
<ChatMessage key={i} groupedChat={groupedChat} />
))}
<ChatChoices choices={choices} onInput={onInput} />
</Box>
)

Expand Down
1 change: 0 additions & 1 deletion grai-frontend/src/components/chat/ChatMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ const ChatMessage: React.FC<ChatMessageProps> = ({ groupedChat }) => {
}}
>
<Typography
align={"left"}
sx={{
px: 2,
py: 1,
Expand Down
18 changes: 16 additions & 2 deletions grai-frontend/src/components/chat/ChatWindow.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,29 @@ import ChatWindow from "./ChatWindow"
const handleInput = jest.fn()

test("renders", async () => {
render(<ChatWindow messages={[]} onInput={handleInput} workspaceId="1" />)
render(
<ChatWindow
messages={[]}
onInput={handleInput}
workspaceId="1"
choices={[]}
/>,
)

expect(screen.getByRole("textbox")).toBeInTheDocument()
})

test("type", async () => {
const user = userEvent.setup()

render(<ChatWindow messages={[]} onInput={handleInput} workspaceId="1" />)
render(
<ChatWindow
messages={[]}
onInput={handleInput}
workspaceId="1"
choices={[]}
/>,
)

expect(screen.getByRole("textbox")).toBeInTheDocument()

Expand Down
4 changes: 3 additions & 1 deletion grai-frontend/src/components/chat/ChatWindow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@ export type Message = {

type ChatWindowProps = {
messages: Message[]
choices: string[]
onInput: (message: string) => void
workspaceId: string
}

const ChatWindow: React.FC<ChatWindowProps> = ({
messages,
choices,
onInput,
workspaceId,
}) => (
<Box sx={{ display: "flex", flexDirection: "column", height: "100%" }}>
<ChatHistory messages={messages} />
<ChatHistory messages={messages} choices={choices} onInput={onInput} />
<ResetChat workspaceId={workspaceId} />
<ChatInput onInput={onInput} />
</Box>
Expand Down
59 changes: 59 additions & 0 deletions grai-frontend/src/components/chat/WebsocketChat.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,65 @@ afterEach(() => {
WS.clean()
})

test("renders", async () => {
const chat = {
id: "1",
messages: {
data: [
{
id: "1",
message: "H",
role: "user",
created_at: "2021-04-20T00:00:00.000000Z",
},
],
},
}

render(<WebsocketChat workspace={workspace} chat={chat} />)

expect(
screen.queryByRole("button", {
name: "Is there a customer table in the prod namespace?",
}),
).not.toBeInTheDocument()
})

test("click choice", async () => {
const user = userEvent.setup()

const chat = {
id: "1",
messages: {
data: [
{
id: "1",
message: "H",
role: "system",
created_at: "2021-04-20T00:00:00.000000Z",
},
],
},
}

render(<WebsocketChat workspace={workspace} chat={chat} />)

expect(
screen.getByRole("button", {
name: "Is there a customer table in the prod namespace?",
}),
).toBeInTheDocument()

await act(
async () =>
await user.click(
screen.getByRole("button", {
name: "Is there a customer table in the prod namespace?",
}),
),
)
})

test("type", async () => {
const user = userEvent.setup()

Expand Down
29 changes: 18 additions & 11 deletions grai-frontend/src/components/chat/WebsocketChat.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useCallback, useEffect } from "react"
import React, { useCallback } from "react"
import { gql, useApolloClient } from "@apollo/client"
import { baseURL } from "client"
import useWebSocket from "react-use-websocket"
Expand Down Expand Up @@ -41,7 +41,14 @@ const WebsocketChat: React.FC<WebsocketChatProps> = ({ workspace, chat }) => {

const socketUrl = `${socketURL}/ws/chat/${workspace.id}/`

const { sendJsonMessage, lastMessage } = useWebSocket(socketUrl)
const { sendJsonMessage } = useWebSocket(socketUrl, {
onMessage: event => {
addMessage({
message: JSON.parse(event.data).message,
role: "system",
})
},
})

const addMessage = useCallback(
(message: Message) =>
Expand Down Expand Up @@ -76,15 +83,6 @@ const WebsocketChat: React.FC<WebsocketChatProps> = ({ workspace, chat }) => {
[cache, chat.id],
)

useEffect(() => {
if (lastMessage) {
addMessage({
message: JSON.parse(lastMessage.data).message,
role: "system",
})
}
}, [lastMessage, addMessage])

const handleInput = (message: string) => {
const msg = {
type: "chat.message",
Expand All @@ -100,9 +98,18 @@ const WebsocketChat: React.FC<WebsocketChatProps> = ({ workspace, chat }) => {
sender: message.role === "user",
}))

const choices = messages.every(m => !m.sender)
? [
"Is there a customer table in the prod namespace?",
"Which tables have an email field?",
"Produce a list of tables that aren't used by any application",
]
: []

return (
<ChatWindow
messages={messages}
choices={choices}
onInput={handleInput}
workspaceId={workspace.id}
/>
Expand Down
9 changes: 9 additions & 0 deletions grai-server/app/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from connections.models import Connection as ConnectionModel
from connections.models import Run as RunModel
from connections.types import Connector, ConnectorFilter
from grAI.models import Message as MessageModel
from grAI.models import MessageRoles
from grAI.models import UserChat as ChatModel
from grAI.types import Chat
from installations.models import Branch as BranchModel
Expand Down Expand Up @@ -1076,6 +1078,13 @@ async def last_chat(
if not chat:
chat = await ChatModel.objects.acreate(membership=membership)

await MessageModel.objects.acreate(
chat=chat,
message="Hello, I'm the GrAI assistant. How can I help you?",
visible=True,
role=MessageRoles.AGENT.value,
)

return chat


Expand Down
17 changes: 15 additions & 2 deletions grai-server/app/grAI/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from api.common import IsAuthenticated, get_user, get_workspace
from api.pagination import DataWrapper

from .models import UserChat
from .models import Message, MessageRoles, UserChat
from .types import Chat


Expand All @@ -21,12 +21,25 @@ def _create(
info: Info,
workspaceId: strawberry.ID,
) -> DataWrapper[Chat]:
def add_message(message: str):
Message.objects.create(
chat=chat,
message=message,
visible=True,
role=MessageRoles.AGENT.value,
)

user = get_user(info)
workspace = get_workspace(info, workspaceId)

membership = workspace.memberships.get(user=user)

return UserChat.objects.create(membership=membership)
chat = UserChat.objects.create(membership=membership)

add_message("Chat restarted")
add_message("Hello, I'm the GrAI assistant. How can I help you?")

return chat

return await sync_to_async(_create)(
info,
Expand Down
Loading