Skip to content

Commit

Permalink
Refactored 'on_invalid_data' to '_convert_unsupported_data()'
Browse files Browse the repository at this point in the history
Adapted tests accordingly
  • Loading branch information
FabienLelaquais committed Nov 17, 2024
1 parent 5b8bf15 commit f7ab246
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 199 deletions.
20 changes: 13 additions & 7 deletions taipy/gui/data/data_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,14 @@ def __init__(self, gui: "Gui") -> None:
self._register(_NumpyDataAccessor)

def _register(self, cls: t.Type[_DataAccessor]) -> None:
"""Register a new DataAccessor type."""
if not inspect.isclass(cls):
raise AttributeError("The argument of 'DataAccessors.register' should be a class")
raise AttributeError("The argument of 'DataAccessors.register()' should be a class")
if not issubclass(cls, _DataAccessor):
raise TypeError(f"Class {cls.__name__} is not a subclass of DataAccessor")
classes = cls.get_supported_classes()
if not classes:
raise TypeError(f"method {cls.__name__}.get_supported_classes returned an invalid value")
raise TypeError(f"{cls.__name__}.get_supported_classes() returned an invalid value")
# check existence
inst: t.Optional[_DataAccessor] = None
for cl in classes:
Expand All @@ -132,16 +133,21 @@ def _register(self, cls: t.Type[_DataAccessor]) -> None:
for cl in classes:
self.__access_4_type[cl] = inst # type: ignore

def _unregister(self, cls: t.Type[_DataAccessor]) -> None:
"""Unregisters a DataAccessor type."""
if cls in self.__access_4_type:
del self.__access_4_type[cls]

def __get_instance(self, value: _TaipyData) -> _DataAccessor: # type: ignore
value = value.get() if isinstance(value, _TaipyData) else value
access = self.__access_4_type.get(type(value))
if access is None:
if value is not None:
transformed_value = self.__gui._handle_invalid_data(value)
if transformed_value is not None:
transformed_access = self.__access_4_type.get(type(transformed_value))
if transformed_access is not None:
return transformed_access
converted_value = type(self.__gui)._convert_unsupported_data(value)
if converted_value is not None:
access = self.__access_4_type.get(type(converted_value))
if access is not None:
return access
_warn(f"Can't find Data Accessor for type {str(type(value))}.")
return self.__invalid_data_accessor
return access
Expand Down
79 changes: 46 additions & 33 deletions taipy/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ class Gui:

__content_providers: t.Dict[type, t.Callable[..., str]] = {}

# See set_unsupported_data_converter()
__unsupported_data_converter: t.Optional[t.Callable] = None

def __init__(
self,
page: t.Optional[t.Union[str, Page]] = None,
Expand Down Expand Up @@ -358,13 +361,6 @@ def __init__(
The returned HTML content can therefore use both the variables stored in the *state*
and the parameters provided in the call to `get_user_content_url()^`.
"""
self.on_invalid_data: t.Optional[t.Callable] = None
"""
The function that is called to transform data into a valid data type.
Invokes the callback to transform unsupported data into a valid format.
:param value: The unsupported data type encountered.
:return: Transformed data or None if no transformation is possible.
"""

# sid from client_id
self.__client_id_2_sid: t.Dict[str, t.Set[str]] = {}
Expand Down Expand Up @@ -414,21 +410,6 @@ def __init__(
for library in libraries:
Gui.add_library(library)

def _handle_invalid_data(self, value: t.Any) -> t.Optional[t.Any]:
"""
Handles unsupported data by invoking the callback if available.
:param value: The unsupported data encountered.
:return: Transformed data or None if no transformation is possible.
"""
try:
if _is_function(self.on_invalid_data):
return self.on_invalid_data(value) # type: ignore
else:
return None
except Exception as e:
_warn(f"Error transforming data: {str(e)}")
return None

@staticmethod
def add_library(library: ElementLibrary) -> None:
"""Add a custom visual element library.
Expand Down Expand Up @@ -469,7 +450,6 @@ def register_content_provider(content_type: type, content_provider: t.Callable[.
Arguments:
content_type: The type of the content that triggers the content provider.
content_provider: The function that converts content of type *type* into an HTML string.
""" # noqa: E501
if Gui.__content_providers.get(content_type):
_warn(f"The type {content_type} is already associated with a provider.")
Expand Down Expand Up @@ -1037,17 +1017,17 @@ def __upload_files(self):
complete = part == total - 1

# Extract upload path (when single file is selected, path="" does not change the path)
upload_root = os.path.abspath( self._get_config( "upload_folder", tempfile.gettempdir() ) )
upload_path = os.path.abspath( os.path.join( upload_root, os.path.dirname(path) ) )
if upload_path.startswith( upload_root ):
upload_path = Path( upload_path ).resolve()
os.makedirs( upload_path, exist_ok=True )
upload_root = os.path.abspath(self._get_config("upload_folder", tempfile.gettempdir()))
upload_path = os.path.abspath(os.path.join(upload_root, os.path.dirname(path)))
if upload_path.startswith(upload_root):
upload_path = Path(upload_path).resolve()
os.makedirs(upload_path, exist_ok=True)
# Save file into upload_path directory
file_path = _get_non_existent_file_path(upload_path, secure_filename(file.filename))
file.save( os.path.join( upload_path, (file_path.name + suffix) ) )
file.save(os.path.join(upload_path, (file_path.name + suffix)))
else:
_warn(f"upload files: Path {path} points outside of upload root.")
return("upload files: Path part points outside of upload root.", 400)
return ("upload files: Path part points outside of upload root.", 400)

if complete:
if part > 0:
Expand Down Expand Up @@ -1083,9 +1063,7 @@ def __upload_files(self):
if not _is_function(file_fn):
file_fn = _getscopeattr(self, on_upload_action)
if _is_function(file_fn):
self._call_function_with_state(
t.cast(t.Callable, file_fn), ["file_upload", {"args": [data]}]
)
self._call_function_with_state(t.cast(t.Callable, file_fn), ["file_upload", {"args": [data]}])
else:
setattr(self._bindings(), var_name, newvalue)
return ("", 200)
Expand Down Expand Up @@ -1165,6 +1143,41 @@ def __update_state_context(self, payload: dict):
for var, val in state_context.items():
self._update_var(var, val, True, forward=False)


@staticmethod
def set_unsupported_data_converter(converter: t.Optional[t.Callable[[t.Any], t.Any]]) -> None:
"""Set a custom converter for unsupported data types.
This function allows specifying a custom conversion function for data types that cannot
be serialized automatically. When Taipy GUI encounters an unsupported type in the data
being serialized, it will invoke the provided *converter* function. The returned value
from this function will then be serialized if possible.
Arguments:
converter: A function that converts a value with an unsupported data type (the only
parameter to the function) into data with a supported data type (the returned value
from the function).</br>
If set to `None`, it removes any existing converter.
"""
Gui.__unsupported_data_converter = converter

@staticmethod
def _convert_unsupported_data(value: t.Any) -> t.Optional[t.Any]:
"""
Handles unsupported data by invoking the converter if there is one.
Arguments:
value: The unsupported data encountered.
Returns:
The transformed data or None if no transformation is possible.
"""
try:
return Gui.__unsupported_data_converter(value) if callable(Gui.__unsupported_data_converter) else None # type: ignore
except Exception as e:
_warn(f"Error transforming data: {str(e)}")
return None

def __request_data_update(self, var_name: str, payload: t.Any) -> None:
# Use custom attrgetter function to allow value binding for _MapDict
newvalue = _getscopeattr_drill(self, var_name)
Expand Down
76 changes: 76 additions & 0 deletions tests/gui/data/test_accessors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from unittest.mock import MagicMock
from taipy.gui import Gui
from taipy.gui.data.data_accessor import (
_DataAccessor,
_DataAccessors,
_InvalidDataAccessor,
)
from taipy.gui.data.data_format import _DataFormat
from taipy.gui.utils.types import _TaipyData
from unittest.mock import Mock
import typing as t


class MyDataAccessor(_DataAccessor):
@staticmethod
def get_supported_classes() -> t.List[t.Type]:
return [int]

def get_data(
self,
var_name: str,
value: t.Any,
payload: t.Dict[str, t.Any],
data_format: _DataFormat,
) -> t.Dict[str, t.Any]:
return {"value": 2*int(value)}

def get_col_types(self, var_name: str, value: t.Any) -> t.Dict[str, str]: # type: ignore
pass

def to_pandas(self, value: t.Any) -> t.Union[t.List[t.Any], t.Any]:
pass

def on_edit(self, value: t.Any, payload: t.Dict[str, t.Any]) -> t.Optional[t.Any]:
pass

def on_delete(self, value: t.Any, payload: t.Dict[str, t.Any]) -> t.Optional[t.Any]:
pass

def on_add(
self,
value: t.Any,
payload: t.Dict[str, t.Any],
new_row: t.Optional[t.List[t.Any]] = None,
) -> t.Optional[t.Any]:
pass

def to_csv(self, var_name: str, value: t.Any) -> t.Optional[str]:
pass

def mock_taipy_data(value):
"""Helper to mock _TaipyData objects."""
mock_data = Mock(spec=_TaipyData)
mock_data.get.return_value = value
return mock_data


def test_custom_accessor(gui: Gui):
"""Test if get_data() uses the correct accessor."""
data_accessors = _DataAccessors(gui)
data = mock_taipy_data(123)

# Testing when accessor is not registered
data_accessor = data_accessors._DataAccessors__get_instance(mock_taipy_data) # type: ignore
assert isinstance(data_accessor, _InvalidDataAccessor), f"Expected _InvalidDataAccessor but got {type(data_accessor)}"
result = data_accessors.get_data("var_name", data, {})
assert result == {}

# Testing when accessor is not registered
data_accessors._register(MyDataAccessor)

result = data_accessors.get_data("var_name", data, {})
assert isinstance(result, dict)
assert result["value"] == 246

data_accessors._unregister(MyDataAccessor)
73 changes: 73 additions & 0 deletions tests/gui/data/test_unsupported_data_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging

import pytest

from taipy.gui import Gui


def test_handle_invalid_data_no_callback():
result = Gui._convert_unsupported_data("invalid_data")
assert result is None



def test_unsupported_data_converter_returns_none():
def convert(value):
return None # Simulates a failed transformation

Gui.set_unsupported_data_converter(convert)

result = Gui._convert_unsupported_data("invalid_data")
assert result is None

# Reset converter
Gui.set_unsupported_data_converter(None)


def test_unsupported_data_converter_applied():
def convert(value):
return "converted" # Successful transformation

Gui.set_unsupported_data_converter(convert)

result = Gui._convert_unsupported_data("raw data")
assert result == "converted"

# Reset converter
Gui.set_unsupported_data_converter(None)


def test_unsupported_data_converter_raises_exception(capfd, monkeypatch):
def convert(value):
raise ValueError("Conversion failure") # Simulate an error

def mock_warn(message: str):
logging.warning(message) # Ensure the warning goes to stderr.
# Patch the _warn function inside the taipy.gui._warnings module.
monkeypatch.setattr("taipy.gui._warnings._warn", mock_warn)

Gui.set_unsupported_data_converter(convert)

result = Gui._convert_unsupported_data("raw data")

out, _ = capfd.readouterr()

assert result is None # Should return None on exception
assert "Error transforming data: Transformation error"

# Reset converter
Gui.set_unsupported_data_converter(None)


@pytest.mark.parametrize("input_data", [None, 123, [], {}, set()])
def test_unsupported_data_converter_with_various_inputs(input_data):
def convert(value):
return "converted" # Always returns valid data

Gui.set_unsupported_data_converter(convert)

result = Gui._convert_unsupported_data(input_data)
assert result == "converted" # Transformed correctly for all inputs

# Reset converter
Gui.set_unsupported_data_converter(None)
Loading

0 comments on commit f7ab246

Please sign in to comment.