Skip to content

Commit

Permalink
support async callback with state
Browse files Browse the repository at this point in the history
resolves #2288
  • Loading branch information
Fred Lefévère-Laoide authored and Fred Lefévère-Laoide committed Nov 29, 2024
1 parent 45cb211 commit 01607e9
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 16 deletions.
18 changes: 13 additions & 5 deletions taipy/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import warnings
from importlib import metadata, util
from importlib.util import find_spec
from inspect import currentframe, getabsfile, ismethod, ismodule
from inspect import currentframe, getabsfile, iscoroutinefunction, ismethod, ismodule
from pathlib import Path
from threading import Thread, Timer
from types import FrameType, FunctionType, LambdaType, ModuleType, SimpleNamespace
Expand Down Expand Up @@ -73,7 +73,7 @@
from .page import Page
from .partial import Partial
from .server import _Server
from .state import State, _GuiState
from .state import State, _AsyncState, _GuiState
from .types import _WsType
from .utils import (
_delscopeattr,
Expand Down Expand Up @@ -115,6 +115,7 @@
from .utils._variable_directory import _is_moduled_variable, _VariableDirectory
from .utils.chart_config_builder import _build_chart_config
from .utils.table_col_builder import _enhance_columns
from .utils.threads import _invoke_async_callback


class Gui:
Expand Down Expand Up @@ -1143,7 +1144,6 @@ def __update_state_context(self, payload: dict):
for var, val in state_context.items():
self._update_var(var, val, True, forward=False)


@staticmethod
def set_unsupported_data_converter(converter: t.Optional[t.Callable[[t.Any], t.Any]]) -> None:
"""Set a custom converter for unsupported data types.
Expand Down Expand Up @@ -1588,7 +1588,12 @@ def __call_function_with_args(self, **kwargs):

def _call_function_with_state(self, user_function: t.Callable, args: t.Optional[t.List[t.Any]] = None) -> t.Any:
cp_args = [] if args is None else args.copy()
cp_args.insert(0, self.__get_state())
cp_args.insert(
0,
_AsyncState(t.cast(_GuiState, self.__get_state()))
if iscoroutinefunction(user_function)
else self.__get_state(),
)
argcount = user_function.__code__.co_argcount
if argcount > 0 and ismethod(user_function):
argcount -= 1
Expand All @@ -1597,7 +1602,10 @@ def _call_function_with_state(self, user_function: t.Callable, args: t.Optional[
else:
cp_args = cp_args[:argcount]
with self.__event_manager:
return user_function(*cp_args)
if iscoroutinefunction(user_function):
return _invoke_async_callback(user_function, cp_args)
else:
return user_function(*cp_args)

def _set_module_context(self, module_context: t.Optional[str]) -> t.ContextManager[None]:
return self._set_locals_context(module_context) if module_context is not None else contextlib.nullcontext()
Expand Down
16 changes: 9 additions & 7 deletions taipy/gui/gui_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ._warnings import _warn
from .gui import Gui
from .state import State
from .utils.callable import _is_function


def download(
Expand Down Expand Up @@ -382,19 +383,20 @@ def invoke_long_callback(
"""
if not state or not isinstance(state._gui, Gui):
_warn("'invoke_long_callback()' must be called in the context of a callback.")
return

if user_status_function_args is None:
user_status_function_args = []
if user_function_args is None:
user_function_args = []

state_id = get_state_id(state)
module_context = get_module_context(state)
this_gui = state.get_gui()

state_id = this_gui._get_client_id()
module_context = this_gui._get_locals_context()
if not isinstance(state_id, str) or not isinstance(module_context, str):
return

this_gui = state._gui

def callback_on_exception(state: State, function_name: str, e: Exception):
if not this_gui._call_on_exception(function_name, e):
_warn(f"invoke_long_callback(): Exception raised in function {function_name}()", e)
Expand All @@ -405,10 +407,10 @@ def callback_on_status(
function_name: t.Optional[str] = None,
function_result: t.Optional[t.Any] = None,
):
if callable(user_status_function):
if _is_function(user_status_function):
this_gui.invoke_callback(
str(state_id),
user_status_function,
t.cast(t.Callable, user_status_function),
[status] + list(user_status_function_args) + [function_result], # type: ignore
str(module_context),
)
Expand Down Expand Up @@ -438,5 +440,5 @@ def thread_status(name: str, period_s: float, count: int):

thread = threading.Thread(target=user_function_in_thread, args=user_function_args)
thread.start()
if isinstance(period, int) and period >= 500 and callable(user_status_function):
if isinstance(period, int) and period >= 500 and _is_function(user_status_function):
thread_status(thread.name, period / 1000.0, 0)
29 changes: 25 additions & 4 deletions taipy/gui/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,7 @@ class _GuiState(State):
"_get_placeholder_attrs",
"_add_attribute",
)
__placeholder_attrs = (
"_taipy_p1",
"_current_context",
)
__placeholder_attrs = ("_taipy_p1", "_current_context", "__state_id")
__excluded_attrs = __attrs + __methods + __placeholder_attrs

def __init__(self, gui: "Gui", var_list: t.Iterable[str], context_list: t.Iterable[str]) -> None:
Expand Down Expand Up @@ -278,3 +275,27 @@ def _add_attribute(self, name: str, default_value: t.Optional[t.Any] = None) ->
gui = super().__getattribute__(_GuiState.__gui_attr)
return gui._bind_var_val(name, default_value)
return False


class _AsyncState(_GuiState):
def __init__(self, state: State) -> None:
super().__init__(state.get_gui(), [], [])
self._set_placeholder("__state_id", state.get_gui()._get_client_id())

@staticmethod
def __set_var_in_state(state: State, var_name: str, value: t.Any):
setattr(state, var_name, value)

@staticmethod
def __get_var_from_state(state: State, var_name: str):
return getattr(state, var_name)

def __setattr__(self, var_name: str, var_value: t.Any) -> None:
self.get_gui().invoke_callback(
t.cast(str, self._get_placeholder("__state_id")), _AsyncState.__set_var_in_state, [var_name, var_value]
)

def __getattr__(self, var_name: str) -> t.Any:
return self.get_gui().invoke_callback(
t.cast(str, self._get_placeholder("__state_id")), _AsyncState.__get_var_from_state, [var_name]
)
22 changes: 22 additions & 0 deletions taipy/gui/utils/threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2021-2024 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
import asyncio
import threading
import typing as t


def _thread_async_target(user_function, args: t.List[t.Any]):
asyncio.run(user_function(*args))


def _invoke_async_callback(user_function, args: t.List[t.Any]):
thread = threading.Thread(target=_thread_async_target, args=[user_function, args])
thread.start()

0 comments on commit 01607e9

Please sign in to comment.