diff --git a/taipy/gui/gui.py b/taipy/gui/gui.py index e5257a62d..2be2076c7 100644 --- a/taipy/gui/gui.py +++ b/taipy/gui/gui.py @@ -25,7 +25,7 @@ import warnings from importlib import metadata, util from importlib.util import find_spec -from inspect import currentframe, getabsfile, ismethod, ismodule +from inspect import currentframe, getabsfile, iscoroutinefunction, ismethod, ismodule from pathlib import Path from threading import Thread, Timer from types import FrameType, FunctionType, LambdaType, ModuleType, SimpleNamespace @@ -73,7 +73,7 @@ from .page import Page from .partial import Partial from .server import _Server -from .state import State, _GuiState +from .state import State, _AsyncState, _GuiState from .types import _WsType from .utils import ( _delscopeattr, @@ -115,6 +115,7 @@ from .utils._variable_directory import _is_moduled_variable, _VariableDirectory from .utils.chart_config_builder import _build_chart_config from .utils.table_col_builder import _enhance_columns +from .utils.threads import _invoke_async_callback class Gui: @@ -1143,7 +1144,6 @@ def __update_state_context(self, payload: dict): for var, val in state_context.items(): self._update_var(var, val, True, forward=False) - @staticmethod def set_unsupported_data_converter(converter: t.Optional[t.Callable[[t.Any], t.Any]]) -> None: """Set a custom converter for unsupported data types. @@ -1588,7 +1588,12 @@ def __call_function_with_args(self, **kwargs): def _call_function_with_state(self, user_function: t.Callable, args: t.Optional[t.List[t.Any]] = None) -> t.Any: cp_args = [] if args is None else args.copy() - cp_args.insert(0, self.__get_state()) + cp_args.insert( + 0, + _AsyncState(t.cast(_GuiState, self.__get_state())) + if iscoroutinefunction(user_function) + else self.__get_state(), + ) argcount = user_function.__code__.co_argcount if argcount > 0 and ismethod(user_function): argcount -= 1 @@ -1597,7 +1602,10 @@ def _call_function_with_state(self, user_function: t.Callable, args: t.Optional[ else: cp_args = cp_args[:argcount] with self.__event_manager: - return user_function(*cp_args) + if iscoroutinefunction(user_function): + return _invoke_async_callback(user_function, cp_args) + else: + return user_function(*cp_args) def _set_module_context(self, module_context: t.Optional[str]) -> t.ContextManager[None]: return self._set_locals_context(module_context) if module_context is not None else contextlib.nullcontext() diff --git a/taipy/gui/gui_actions.py b/taipy/gui/gui_actions.py index 979c14156..fa65dd50b 100644 --- a/taipy/gui/gui_actions.py +++ b/taipy/gui/gui_actions.py @@ -15,6 +15,7 @@ from ._warnings import _warn from .gui import Gui from .state import State +from .utils.callable import _is_function def download( @@ -382,19 +383,20 @@ def invoke_long_callback( """ if not state or not isinstance(state._gui, Gui): _warn("'invoke_long_callback()' must be called in the context of a callback.") + return if user_status_function_args is None: user_status_function_args = [] if user_function_args is None: user_function_args = [] - state_id = get_state_id(state) - module_context = get_module_context(state) + this_gui = state.get_gui() + + state_id = this_gui._get_client_id() + module_context = this_gui._get_locals_context() if not isinstance(state_id, str) or not isinstance(module_context, str): return - this_gui = state._gui - def callback_on_exception(state: State, function_name: str, e: Exception): if not this_gui._call_on_exception(function_name, e): _warn(f"invoke_long_callback(): Exception raised in function {function_name}()", e) @@ -405,10 +407,10 @@ def callback_on_status( function_name: t.Optional[str] = None, function_result: t.Optional[t.Any] = None, ): - if callable(user_status_function): + if _is_function(user_status_function): this_gui.invoke_callback( str(state_id), - user_status_function, + t.cast(t.Callable, user_status_function), [status] + list(user_status_function_args) + [function_result], # type: ignore str(module_context), ) @@ -438,5 +440,5 @@ def thread_status(name: str, period_s: float, count: int): thread = threading.Thread(target=user_function_in_thread, args=user_function_args) thread.start() - if isinstance(period, int) and period >= 500 and callable(user_status_function): + if isinstance(period, int) and period >= 500 and _is_function(user_status_function): thread_status(thread.name, period / 1000.0, 0) diff --git a/taipy/gui/state.py b/taipy/gui/state.py index 32b2d248d..7aad4219f 100644 --- a/taipy/gui/state.py +++ b/taipy/gui/state.py @@ -171,10 +171,7 @@ class _GuiState(State): "_get_placeholder_attrs", "_add_attribute", ) - __placeholder_attrs = ( - "_taipy_p1", - "_current_context", - ) + __placeholder_attrs = ("_taipy_p1", "_current_context", "__state_id") __excluded_attrs = __attrs + __methods + __placeholder_attrs def __init__(self, gui: "Gui", var_list: t.Iterable[str], context_list: t.Iterable[str]) -> None: @@ -278,3 +275,27 @@ def _add_attribute(self, name: str, default_value: t.Optional[t.Any] = None) -> gui = super().__getattribute__(_GuiState.__gui_attr) return gui._bind_var_val(name, default_value) return False + + +class _AsyncState(_GuiState): + def __init__(self, state: State) -> None: + super().__init__(state.get_gui(), [], []) + self._set_placeholder("__state_id", state.get_gui()._get_client_id()) + + @staticmethod + def __set_var_in_state(state: State, var_name: str, value: t.Any): + setattr(state, var_name, value) + + @staticmethod + def __get_var_from_state(state: State, var_name: str): + return getattr(state, var_name) + + def __setattr__(self, var_name: str, var_value: t.Any) -> None: + self.get_gui().invoke_callback( + t.cast(str, self._get_placeholder("__state_id")), _AsyncState.__set_var_in_state, [var_name, var_value] + ) + + def __getattr__(self, var_name: str) -> t.Any: + return self.get_gui().invoke_callback( + t.cast(str, self._get_placeholder("__state_id")), _AsyncState.__get_var_from_state, [var_name] + ) diff --git a/taipy/gui/utils/threads.py b/taipy/gui/utils/threads.py new file mode 100644 index 000000000..36e0e2af4 --- /dev/null +++ b/taipy/gui/utils/threads.py @@ -0,0 +1,22 @@ +# Copyright 2021-2024 Avaiga Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. +import asyncio +import threading +import typing as t + + +def _thread_async_target(user_function, args: t.List[t.Any]): + asyncio.run(user_function(*args)) + + +def _invoke_async_callback(user_function, args: t.List[t.Any]): + thread = threading.Thread(target=_thread_async_target, args=[user_function, args]) + thread.start()