Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Add dataclasses support. #3

Merged
merged 2 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Here is a table of supported types:
| `list` | `st.text_area(argument_name)` that will be parsed into JSON object (for `list[str]`) and `st.multiselect(argument_name, literal_values)` for `list[Literal[str1,str2,...]]`. Only these two are supported. | `st.json(result)` | <https://docs.streamlit.io/library/api-reference/widgets/st.text_area> <https://docs.streamlit.io/library/api-reference/widgets/st.multiselect> | <https://docs.streamlit.io/library/api-reference/data/st.json> |
| `dict` | `st.text_area(argument_name)` that will be parsed into JSON object. Only `dict` annotation is supported. | `st.json(result)` | <https://docs.streamlit.io/library/api-reference/widgets/st.text_area> | <https://docs.streamlit.io/library/api-reference/data/st.json> |
| `pandas.DataFrame` | ❌ | `st.dataframe(result)` | ❌ | <https://docs.streamlit.io/library/api-reference/data/st.dataframe> |
| `dataclasses.dataclass` | All the attributes of dataclass will be treated as separate annotations (exploded). If provided, it needs to be the only annotation - callable can have only one parameter. | `st.write(result)` | ❌ | <https://docs.streamlit.io/library/api-reference/write-magic/st.write> |

If your annotation is not in the list, the output will be rendered with `st.write` by default.

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "streamlit_internal_app"
version = "0.2.0"
version = "0.3.0"
description = ""
authors = ["Kacper Wojtasinski <k0wojtasinski@gmail.com>"]
readme = "README.md"
Expand Down
50 changes: 44 additions & 6 deletions streamlit_internal_app/component.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import datetime
import inspect
import json
Expand All @@ -9,6 +10,11 @@
import streamlit as st
from pandas import DataFrame

from streamlit_internal_app.exceptions import (
ComponentFormWithDataclassHasMoreThanOneFieldError,
ComponentFormWithUnsupportedAnnotationError,
)


class StreamlitHandler(logging.Handler):
def emit(self, record: logging.LogRecord):
Expand Down Expand Up @@ -47,18 +53,22 @@ def __init__(self, function: Callable):
"""
self.component = function
self.name = function.__name__.replace("__", "_").replace("_", " ").capitalize()
self.dataclass_class = None

def run(self, params: dict[str, Any]):
def run(self, params: dict[str, Any] | Any):
"""
Runs the component function with the provided parameters.

Args:
params (dict[str, Any]): The parameters to be passed to the component function.
params (dict[str, Any]) | Any: The parameter(s) to be passed to the component function.

Returns:
Any: The result of running the component function.
"""
return self.component(**params)
if isinstance(params, dict):
return self.component(**params)

return self.component(params) # handle dataclasses

@staticmethod
def _render_result(result: Any, return_type: Any) -> None:
Expand Down Expand Up @@ -171,15 +181,40 @@ def _render_element(name: str, annotation: Any) -> Any:
get_args(get_args(annotation)[0]),
placeholder=name,
)
raise ValueError(f"Unsupported type {annotation}")
raise ComponentFormWithUnsupportedAnnotationError(annotation)

def _prepare_annotations(self) -> dict[str, Any]:
"""
Prepares the annotations for the component.
It checks if the component has a schema (dataclass) and if it has other types.
If the component has a schema, it returns the schema annotations.
"""
annotations = {}
has_dataclass = False

for annotation_name, annotation in self.component.__annotations__.items():
if dataclasses.is_dataclass(annotation):
has_dataclass = True
self.dataclass_class = annotation
for field in dataclasses.fields(annotation):
annotations[field.name] = field.type
else:
annotations[annotation_name] = annotation

original_annotations = self.component.__annotations__.copy()
original_annotations.pop("return", None)

if has_dataclass is True and len(original_annotations) > 1:
raise ComponentFormWithDataclassHasMoreThanOneFieldError()
return annotations

def render(self):
"""
Renders the component form.

Displays the component name, docstring (if available), and input elements for the component's parameters.
"""
annotations = self.component.__annotations__
annotations = self._prepare_annotations()
return_type = annotations.pop("return", None)

st.write(f"# {self.name}")
Expand All @@ -198,7 +233,10 @@ def render(self):
with st.spinner("Running..."):
st.write("## Results")
try:
result = self.run(results)
if self.dataclass_class is not None:
result = self.run(self.dataclass_class(**results))
else:
result = self.run(results)
self._render_result(result, return_type)
except Exception as e:
st.exception(e)
Expand Down
29 changes: 29 additions & 0 deletions streamlit_internal_app/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
class ComponentFormWithUnsupportedAnnotationError(Exception):
"""
Exception raised when an unsupported annotation is encountered in ComponentForm.

Attributes:
annotation (type): The unsupported annotation that caused the error.
"""

def __init__(self, annotation: type) -> None:
"""
Initializes a new instance of the ComponentFormWithUnsupportedAnnotationError class.

Args:
annotation (type): The unsupported annotation that caused the error.
"""
super().__init__(
f"Annotation {annotation} is not supported by ComponentForm.",
)


class ComponentFormWithDataclassHasMoreThanOneFieldError(Exception):
"""
Exception raised when a ComponentForm has both dataclass and other fields at the same time.
"""

def __init__(self) -> None:
super().__init__(
"ComponentForm cannot have dataclass and other fields at the same time.",
)
25 changes: 25 additions & 0 deletions tests/static/example_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import pathlib
import random
from dataclasses import dataclass
from typing import Literal, Union

import pandas as pd
Expand All @@ -14,6 +15,12 @@
logger = logging.getLogger(__name__)


@dataclass
class Event:
name: str
date: datetime.date


app = StreamlitInternalApp(
title="Example App",
description="This is my app",
Expand Down Expand Up @@ -214,4 +221,22 @@ def get_current_datetime() -> datetime.datetime:
return datetime.datetime.today()


@app.component
def get_time_until_event(event: Event) -> str:
"""
Get the time until the specified event.

Args:
event (Event): An instance of the Event class representing the event.

Returns:
str: A string indicating the time until the event. If the event has already occurred, it returns "Event [event name] has already occurred.". Otherwise, it returns "Time until event [event name]: [number of days] days."

"""
results = event.date - datetime.date.today()
if results.days < 0:
return f"Event {event.name} has already occurred."
return f"Time until event {event.name}: {results.days} days."


app.render()
24 changes: 24 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,27 @@ def test_app_renders_get_current_time_as_expected(example_app: AppTest):
text = [x.value for x in example_app.markdown]
date_time = datetime.datetime.fromisoformat(text[-2].strip("`"))
assert date_time.date() == datetime.datetime.now().date()


def test_app_renders_get_time_until_event_as_expected(example_app: AppTest):
yesterday = datetime.datetime.today() - datetime.timedelta(days=1)
example_app.sidebar.radio[0].set_value("Get time until event").run()
assert len(example_app.date_input) == 1
assert len(example_app.text_input) == 1
example_app.text_input[0].set_value("TestEvent").run()
example_app.date_input[0].set_value(
datetime.datetime(
year=yesterday.year,
month=yesterday.month,
day=yesterday.day,
),
).run()
example_app.button[0].click().run()
text = [x.value for x in example_app.markdown]
assert text == [
"## Source code of get_time_until_event",
"# Get time until event",
"## Results",
"Event TestEvent has already occurred.",
"This is my app",
]
29 changes: 29 additions & 0 deletions tests/test_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from dataclasses import dataclass

import pytest

from streamlit_internal_app.component import ComponentForm
from streamlit_internal_app.exceptions import (
ComponentFormWithDataclassHasMoreThanOneFieldError,
ComponentFormWithUnsupportedAnnotationError,
)


def test_component_form_with_unsupported_type_should_raise_exception():
def faulty_function(data: bytes):
...

with pytest.raises(ComponentFormWithUnsupportedAnnotationError):
ComponentForm(faulty_function).render()


def test_component_form_with_dataclass_and_other_fields_should_raise_exception():
@dataclass
class DataClass:
a: int

def faulty_function(data: DataClass, other: int):
...

with pytest.raises(ComponentFormWithDataclassHasMoreThanOneFieldError):
ComponentForm(faulty_function).render()
Loading