diff --git a/homeassistant/components/esphome/dashboard.py b/homeassistant/components/esphome/dashboard.py index b0a37aefd0d2d9..334c16e57301ad 100644 --- a/homeassistant/components/esphome/dashboard.py +++ b/homeassistant/components/esphome/dashboard.py @@ -12,6 +12,7 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.singleton import singleton from homeassistant.helpers.storage import Store +from homeassistant.util.hass_dict import HassKey from .const import DOMAIN from .coordinator import ESPHomeDashboardCoordinator @@ -19,7 +20,9 @@ _LOGGER = logging.getLogger(__name__) -KEY_DASHBOARD_MANAGER = "esphome_dashboard_manager" +KEY_DASHBOARD_MANAGER: HassKey[ESPHomeDashboardManager] = HassKey( + "esphome_dashboard_manager" +) STORAGE_KEY = "esphome.dashboard" STORAGE_VERSION = 1 @@ -33,7 +36,7 @@ async def async_setup(hass: HomeAssistant) -> None: await async_get_or_create_dashboard_manager(hass) -@singleton(KEY_DASHBOARD_MANAGER) +@singleton(KEY_DASHBOARD_MANAGER, async_=True) async def async_get_or_create_dashboard_manager( hass: HomeAssistant, ) -> ESPHomeDashboardManager: @@ -140,7 +143,7 @@ def async_get_dashboard(hass: HomeAssistant) -> ESPHomeDashboardCoordinator | No where manager can be an asyncio.Event instead of the actual manager because the singleton decorator is not yet done. """ - manager: ESPHomeDashboardManager | None = hass.data.get(KEY_DASHBOARD_MANAGER) + manager = hass.data.get(KEY_DASHBOARD_MANAGER) return manager.async_get() if manager else None diff --git a/homeassistant/helpers/singleton.py b/homeassistant/helpers/singleton.py index 20e4ee82162de9..075fc50b49af86 100644 --- a/homeassistant/helpers/singleton.py +++ b/homeassistant/helpers/singleton.py @@ -3,15 +3,22 @@ from __future__ import annotations import asyncio -from collections.abc import Callable +from collections.abc import Callable, Coroutine import functools -from typing import Any, cast, overload +from typing import Any, Literal, assert_type, cast, overload from homeassistant.core import HomeAssistant from homeassistant.loader import bind_hass from homeassistant.util.hass_dict import HassKey type _FuncType[_T] = Callable[[HomeAssistant], _T] +type _Coro[_T] = Coroutine[Any, Any, _T] + + +@overload +def singleton[_T]( + data_key: HassKey[_T], *, async_: Literal[True] +) -> Callable[[_FuncType[_Coro[_T]]], _FuncType[_Coro[_T]]]: ... @overload @@ -24,29 +31,37 @@ def singleton[_T]( def singleton[_T](data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ... -def singleton[_T](data_key: Any) -> Callable[[_FuncType[_T]], _FuncType[_T]]: +def singleton[_S, _T, _U]( + data_key: Any, *, async_: bool = False +) -> Callable[[_FuncType[_S]], _FuncType[_S]]: """Decorate a function that should be called once per instance. Result will be cached and simultaneous calls will be handled. """ - def wrapper(func: _FuncType[_T]) -> _FuncType[_T]: + @overload + def wrapper(func: _FuncType[_Coro[_T]]) -> _FuncType[_Coro[_T]]: ... + + @overload + def wrapper(func: _FuncType[_U]) -> _FuncType[_U]: ... + + def wrapper(func: _FuncType[_Coro[_T] | _U]) -> _FuncType[_Coro[_T] | _U]: """Wrap a function with caching logic.""" if not asyncio.iscoroutinefunction(func): @functools.lru_cache(maxsize=1) @bind_hass @functools.wraps(func) - def wrapped(hass: HomeAssistant) -> _T: + def wrapped(hass: HomeAssistant) -> _U: if data_key not in hass.data: hass.data[data_key] = func(hass) - return cast(_T, hass.data[data_key]) + return cast(_U, hass.data[data_key]) return wrapped @bind_hass @functools.wraps(func) - async def async_wrapped(hass: HomeAssistant) -> Any: + async def async_wrapped(hass: HomeAssistant) -> _T: if data_key not in hass.data: evt = hass.data[data_key] = asyncio.Event() result = await func(hass) @@ -62,6 +77,45 @@ async def async_wrapped(hass: HomeAssistant) -> Any: return cast(_T, obj_or_evt) - return async_wrapped # type: ignore[return-value] + return async_wrapped return wrapper + + +async def _test_singleton_typing(hass: HomeAssistant) -> None: + """Test singleton overloads work as intended. + + This is tested during the mypy run. Do not move it to 'tests'! + """ + # Test HassKey + key = HassKey[int]("key") + + @singleton(key) + def func(hass: HomeAssistant) -> int: + return 2 + + @singleton(key, async_=True) + async def async_func(hass: HomeAssistant) -> int: + return 2 + + assert_type(func(hass), int) + assert_type(await async_func(hass), int) + + # Test invalid use of 'async_' with sync function + @singleton(key, async_=True) # type: ignore[arg-type] + def func_error(hass: HomeAssistant) -> int: + return 2 + + # Test string key + other_key = "key" + + @singleton(other_key) + def func2(hass: HomeAssistant) -> str: + return "" + + @singleton(other_key) + async def async_func2(hass: HomeAssistant) -> str: + return "" + + assert_type(func2(hass), str) + assert_type(await async_func2(hass), str)