Skip to content

Commit

Permalink
Display pandas DataFrame as interactive table (#1373)
Browse files Browse the repository at this point in the history
* Add Dataframe class
* Add Dataframe component
* Add @mui/x-data-grid package
* Refactor Dataframe element to handle DataFrame serialization internally

---------

Co-authored-by: Mathijs de Bruin <mathijs@mathijsfietst.nl>
  • Loading branch information
desertproject and dokterbob authored Oct 16, 2024
1 parent 00f96b8 commit f6159d5
Show file tree
Hide file tree
Showing 17 changed files with 384 additions and 21 deletions.
1 change: 1 addition & 0 deletions backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from chainlit.element import (
Audio,
Component,
Dataframe,
File,
Image,
Pdf,
Expand Down
1 change: 0 additions & 1 deletion backend/chainlit/data/sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import aiofiles
import aiohttp

from chainlit.data.base import BaseDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.data.utils import queue_until_user_message
Expand Down
30 changes: 29 additions & 1 deletion backend/chainlit/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,16 @@
}

ElementType = Literal[
"image", "text", "pdf", "tasklist", "audio", "video", "file", "plotly", "component"
"image",
"text",
"pdf",
"tasklist",
"audio",
"video",
"file",
"plotly",
"dataframe",
"component",
]
ElementDisplay = Literal["inline", "side", "page"]
ElementSize = Literal["small", "medium", "large"]
Expand Down Expand Up @@ -358,6 +367,25 @@ def __post_init__(self) -> None:
super().__post_init__()


@dataclass
class Dataframe(Element):
"""Useful to send a pandas DataFrame to the UI."""

type: ClassVar[ElementType] = "dataframe"
size: ElementSize = "large"
data: Any = None # The type is Any because it is checked in __post_init__.

def __post_init__(self) -> None:
"""Ensures the data is a pandas DataFrame and converts it to JSON."""
from pandas import DataFrame

if not isinstance(self.data, DataFrame):
raise TypeError("data must be a pandas.DataFrame")

self.content = self.data.to_json(orient="split", date_format="iso")
super().__post_init__()


@dataclass
class Component(Element):
"""Useful to send a custom component to the UI."""
Expand Down
28 changes: 27 additions & 1 deletion backend/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ slack_bolt = "^1.18.1"
discord = "^2.3.2"
botbuilder-core = "^4.15.0"
aiosqlite = "^0.20.0"
pandas = "^2.2.2"
moto = "^5.0.14"

[tool.poetry.group.dev.dependencies]
Expand All @@ -94,6 +95,7 @@ mypy = "^1.7.1"
types-requests = "^2.31.0.2"
types-aiofiles = "^23.1.0.5"
mypy-boto3-dynamodb = "^1.34.113"
pandas-stubs = { version = "^2.2.2", python = ">=3.9" }

[tool.mypy]
python_version = "3.9"
Expand All @@ -120,7 +122,6 @@ ignore_missing_imports = true




[tool.poetry.group.custom-data]
optional = true

Expand Down
3 changes: 1 addition & 2 deletions backend/tests/data/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest

from unittest.mock import AsyncMock

import pytest
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.user import User

Expand Down
36 changes: 23 additions & 13 deletions backend/tests/data/test_sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from pathlib import Path

import pytest
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.element import Text
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

from chainlit import User
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.element import Text


@pytest.fixture
Expand All @@ -23,18 +23,21 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
# Ref: https://docs.chainlit.io/data-persistence/custom#sql-alchemy-data-layer
async with engine.begin() as conn:
await conn.execute(
text("""
text(
"""
CREATE TABLE users (
"id" UUID PRIMARY KEY,
"identifier" TEXT NOT NULL UNIQUE,
"metadata" JSONB NOT NULL,
"createdAt" TEXT
);
""")
"""
)
)

await conn.execute(
text("""
text(
"""
CREATE TABLE IF NOT EXISTS threads (
"id" UUID PRIMARY KEY,
"createdAt" TEXT,
Expand All @@ -45,11 +48,13 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
"metadata" JSONB,
FOREIGN KEY ("userId") REFERENCES users("id") ON DELETE CASCADE
);
""")
"""
)
)

await conn.execute(
text("""
text(
"""
CREATE TABLE IF NOT EXISTS steps (
"id" UUID PRIMARY KEY,
"name" TEXT NOT NULL,
Expand All @@ -72,11 +77,13 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
"language" TEXT,
"indent" INT
);
""")
"""
)
)

await conn.execute(
text("""
text(
"""
CREATE TABLE IF NOT EXISTS elements (
"id" UUID PRIMARY KEY,
"threadId" UUID,
Expand All @@ -92,19 +99,22 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
"forId" UUID,
"mime" TEXT
);
""")
"""
)
)

await conn.execute(
text("""
text(
"""
CREATE TABLE IF NOT EXISTS feedbacks (
"id" UUID PRIMARY KEY,
"forId" UUID NOT NULL,
"threadId" UUID NOT NULL,
"value" INT NOT NULL,
"comment" TEXT
);
""")
"""
)
)

# Create SQLAlchemyDataLayer instance
Expand Down
68 changes: 68 additions & 0 deletions cypress/e2e/dataframe/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pandas as pd

import chainlit as cl


@cl.on_chat_start
async def start():
# Create a sample DataFrame with more than 10 rows to test pagination functionality
data = {
"Name": [
"Alice",
"David",
"Charlie",
"Bob",
"Eva",
"Grace",
"Hannah",
"Jack",
"Frank",
"Kara",
"Liam",
"Ivy",
"Mia",
"Noah",
"Olivia",
],
"Age": [25, 40, 35, 30, 45, 55, 60, 70, 50, 75, 80, 65, 85, 90, 95],
"City": [
"New York",
"Houston",
"Chicago",
"Los Angeles",
"Phoenix",
"San Antonio",
"San Diego",
"San Jose",
"Philadelphia",
"Austin",
"Fort Worth",
"Dallas",
"Jacksonville",
"Columbus",
"Charlotte",
],
"Salary": [
70000,
100000,
90000,
80000,
110000,
130000,
140000,
160000,
120000,
170000,
180000,
150000,
190000,
200000,
210000,
],
}

df = pd.DataFrame(data)

elements = [cl.Dataframe(data=df, display="inline", name="Dataframe")]

await cl.Message(content="This message has a Dataframe", elements=elements).send()
41 changes: 41 additions & 0 deletions cypress/e2e/dataframe/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import { runTestServer } from '../../support/testUtils';

describe('dataframe', () => {
before(() => {
runTestServer();
});

it('should be able to display an inline dataframe', () => {
// Check if the DataFrame is rendered within the first step
cy.get('.step').should('have.length', 1);
cy.get('.step').first().find('.MuiDataGrid-main').should('have.length', 1);

// Click the sort button in the "Age" column header to sort in ascending order
cy.get('.MuiDataGrid-columnHeader[aria-label="Age"]')
.find('button')
.first()
.click({ force: true });
// Verify the first row's "Age" cell contains '25' after sorting
cy.get('.MuiDataGrid-row')
.first()
.find('.MuiDataGrid-cell[data-field="Age"] .MuiDataGrid-cellContent')
.should('have.text', '25');

// Click the "Next page" button in the pagination controls
cy.get('.MuiTablePagination-actions').find('button').eq(1).click();
// Verify that the next page contains exactly 5 rows
cy.get('.MuiDataGrid-row').should('have.length', 5);

// Click the input to open the dropdown
cy.get('.MuiTablePagination-select').click();
// Select the option with the value '50' from the dropdown list
cy.get('ul.MuiMenu-list li').contains('50').click();
// Scroll to the bottom of the virtual scroller in the MUI DataGrid
cy.get('.MuiDataGrid-virtualScroller').scrollTo('bottom');
// Check that tha last name is Olivia
cy.get('.MuiDataGrid-row')
.last()
.find('.MuiDataGrid-cell[data-field="Name"] .MuiDataGrid-cellContent')
.should('have.text', 'Olivia');
});
});
1 change: 1 addition & 0 deletions frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"@mui/icons-material": "^5.14.9",
"@mui/lab": "^5.0.0-alpha.122",
"@mui/material": "^5.14.10",
"@mui/x-data-grid": "^6.20.4",
"formik": "^2.4.3",
"highlight.js": "^11.9.0",
"i18next": "^23.7.16",
Expand Down
Loading

0 comments on commit f6159d5

Please sign in to comment.