diff --git a/tests/test_device.py b/tests/test_device.py index f9666627f..79d3ee572 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -120,7 +120,7 @@ async def ota_zha_device( async def _send_time_changed(zha_gateway: Gateway, seconds: int): """Send a time changed event.""" await asyncio.sleep(seconds) - await zha_gateway.async_block_till_done() + await zha_gateway.async_block_till_done(wait_background_tasks=True) @patch( @@ -132,11 +132,15 @@ async def test_check_available_success( zha_gateway: Gateway, device_with_basic_cluster_handler: ZigpyDevice, # pylint: disable=redefined-outer-name device_joined: Callable[[ZigpyDevice], Awaitable[Device]], + caplog: pytest.LogCaptureFixture, ) -> None: """Check device availability success on 1st try.""" zha_device = await device_joined(device_with_basic_cluster_handler) basic_ch = device_with_basic_cluster_handler.endpoints[3].basic + assert not zha_device.is_coordinator + assert not zha_device.is_active_coordinator + basic_ch.read_attributes.reset_mock() device_with_basic_cluster_handler.last_seen = None assert zha_device.available is True @@ -156,22 +160,46 @@ def _update_last_seen(*args, **kwargs): # pylint: disable=unused-argument basic_ch.read_attributes.side_effect = _update_last_seen # successfully ping zigpy device, but zha_device is not yet available - await _send_time_changed(zha_gateway, zha_device.__polling_interval + 1) + await _send_time_changed( + zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + ) assert basic_ch.read_attributes.await_count == 1 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is False # There was traffic from the device: pings, but not yet available - await _send_time_changed(zha_gateway, zha_device.__polling_interval + 1) + await _send_time_changed( + zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + ) assert basic_ch.read_attributes.await_count == 2 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is False # There was traffic from the device: don't try to ping, marked as available - await _send_time_changed(zha_gateway, zha_device.__polling_interval + 1) + await _send_time_changed( + zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + ) assert basic_ch.read_attributes.await_count == 2 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is True + assert zha_device.on_network is True + + assert "Device is not on the network, marking unavailable" not in caplog.text + zha_device.on_network = False + + assert zha_device.available is False + assert zha_device.on_network is False + + sleep_time = max( + zha_gateway.global_updater.__polling_interval, + zha_gateway._device_availability_checker.__polling_interval, + ) + sleep_time += 2 + + await asyncio.sleep(sleep_time) + await zha_gateway.async_block_till_done(wait_background_tasks=True) + + assert "Device is not on the network, marking unavailable" in caplog.text @patch( @@ -197,21 +225,27 @@ async def test_check_available_unsuccessful( ) # unsuccessfully ping zigpy device, but zha_device is still available - await _send_time_changed(zha_gateway, zha_device.__polling_interval + 1) + await _send_time_changed( + zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + ) assert basic_ch.read_attributes.await_count == 1 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is True # still no traffic, but zha_device is still available - await _send_time_changed(zha_gateway, zha_device.__polling_interval + 1) + await _send_time_changed( + zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + ) assert basic_ch.read_attributes.await_count == 2 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is True # not even trying to update, device is unavailable - await _send_time_changed(zha_gateway, zha_device.__polling_interval + 1) + await _send_time_changed( + zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + ) assert basic_ch.read_attributes.await_count == 2 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] @@ -241,7 +275,9 @@ async def test_check_available_no_basic_cluster_handler( ) assert "does not have a mandatory basic cluster" not in caplog.text - await _send_time_changed(zha_gateway, zha_device.__polling_interval + 1) + await _send_time_changed( + zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + ) assert zha_device.available is False assert "does not have a mandatory basic cluster" in caplog.text diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 61621b5c0..fee532b9b 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -173,6 +173,7 @@ async def test_device_left( zha_gateway.device_left(zigpy_dev_basic) await zha_gateway.async_block_till_done() assert zha_dev_basic.available is False + assert zha_dev_basic.on_network is False async def test_gateway_group_methods( @@ -390,6 +391,7 @@ async def test_gateway_force_multi_pan_channel( @pytest.mark.parametrize("radio_concurrency", [1, 2, 8]) +@pytest.mark.looptime async def test_startup_concurrency_limit( radio_concurrency: int, zigpy_app_controller: ControllerApplication, @@ -567,3 +569,57 @@ def test_gateway_raw_device_initialized( event="raw_device_initialized", ), ) + + +@pytest.mark.looptime +async def test_pollers_skip( + zha_gateway: Gateway, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test pollers skip when they should.""" + + assert "Global updater interval skipped" not in caplog.text + assert "Device availability checker interval skipped" not in caplog.text + + assert zha_gateway.config.allow_polling is True + zha_gateway.config.allow_polling = False + assert zha_gateway.config.allow_polling is False + + sleep_time = max( + zha_gateway.global_updater.__polling_interval, + zha_gateway._device_availability_checker.__polling_interval, + ) + sleep_time += 2 + + await asyncio.sleep(sleep_time) + await zha_gateway.async_block_till_done(wait_background_tasks=True) + + assert "Global updater interval skipped" in caplog.text + assert "Device availability checker interval skipped" in caplog.text + + +async def test_gateway_handle_message( + zha_gateway: Gateway, + zha_dev_basic: Device, # pylint: disable=redefined-outer-name +) -> None: + """Test handle message.""" + + assert zha_dev_basic.available is True + assert zha_dev_basic.on_network is True + + zha_dev_basic.on_network = False + + assert zha_dev_basic.available is False + assert zha_dev_basic.on_network is False + + zha_gateway.handle_message( + zha_dev_basic.device, + zha.PROFILE_ID, + general.Basic.cluster_id, + 1, + 1, + b"", + ) + + assert zha_dev_basic.available is True + assert zha_dev_basic.on_network is True diff --git a/tests/test_sensor.py b/tests/test_sensor.py index 5637e4f88..141a3c8b0 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -643,15 +643,16 @@ async def test_electrical_measurement_init( assert cluster_handler.ac_power_multiplier == 1 assert entity.state["state"] == 4.0 - zha_device.available = False + zha_device.on_network = False - await asyncio.sleep(70) + await asyncio.sleep(entity.__polling_interval + 1) + await zha_gateway.async_block_till_done(wait_background_tasks=True) assert ( "1-2820: skipping polling for updated state, available: False, allow polled requests: True" in caplog.text ) - zha_device.available = True + zha_device.on_network = True await send_attributes_report( zha_gateway, @@ -1173,13 +1174,14 @@ async def test_device_counter_sensors( "counter_1" ].increment() - await entity.async_update() - await zha_gateway.async_block_till_done() + await asyncio.sleep(zha_gateway.global_updater.__polling_interval + 2) + await zha_gateway.async_block_till_done(wait_background_tasks=True) assert entity.state["state"] == 2 coordinator.available = False await asyncio.sleep(120) + await zha_gateway.async_block_till_done(wait_background_tasks=True) assert ( "counter_1: skipping polling for updated state, available: False, allow polled requests: True" diff --git a/zha/application/gateway.py b/zha/application/gateway.py index 9eeb48963..b6456c9a5 100644 --- a/zha/application/gateway.py +++ b/zha/application/gateway.py @@ -42,7 +42,7 @@ ZHA_GW_MSG_RAW_INIT, RadioType, ) -from zha.application.helpers import ZHAData +from zha.application.helpers import DeviceAvailabilityChecker, GlobalUpdater, ZHAData from zha.async_ import ( AsyncUtilMixin, create_eager_task, @@ -162,6 +162,10 @@ def __init__(self, config: ZHAData) -> None: setup_quirks( custom_quirks_path=config.yaml_config.get(CONF_CUSTOM_QUIRKS_PATH) ) + self.global_updater: GlobalUpdater = GlobalUpdater(self) + self._device_availability_checker: DeviceAvailabilityChecker = ( + DeviceAvailabilityChecker(self) + ) self.config.gateway = self def get_application_controller_data(self) -> tuple[ControllerApplication, dict]: @@ -224,6 +228,8 @@ async def async_initialize(self) -> None: self.application_controller.add_listener(self) self.application_controller.groups.add_listener(self) + self.global_updater.start() + self._device_availability_checker.start() def connection_lost(self, exc: Exception) -> None: """Handle connection lost event.""" @@ -408,7 +414,10 @@ def device_initialized(self, device: zigpy.device.Device) -> None: def device_left(self, device: zigpy.device.Device) -> None: """Handle device leaving the network.""" - self.async_update_device(device, False) + zha_device: Device = self._devices.get(device.ieee) + if zha_device is not None: + zha_device.on_network = False + self.async_update_device(device, available=False) def group_member_removed( self, zigpy_group: zigpy.group.Group, endpoint: zigpy.endpoint.Endpoint @@ -522,7 +531,9 @@ def get_or_create_group(self, zigpy_group: zigpy.group.Group) -> Group: return zha_group def async_update_device( - self, sender: zigpy.device.Device, available: bool = True + self, + sender: zigpy.device.Device, + available: bool = True, ) -> None: """Update device that has just become available.""" if sender.ieee in self.devices: @@ -569,6 +580,7 @@ async def async_device_initialized(self, device: zigpy.device.Device) -> None: async def _async_device_joined(self, zha_device: Device) -> None: zha_device.available = True + zha_device.on_network = True await zha_device.async_configure() device_info = ExtendedDeviceInfoWithPairingStatus( pairing_status=DevicePairingStatus.CONFIGURED.name, @@ -600,7 +612,7 @@ async def _async_device_rejoined(self, zha_device: Device) -> None: ) # force async_initialize() to fire so don't explicitly call it zha_device.available = False - zha_device.update_available(True) + zha_device.on_network = True async def async_create_zigpy_group( self, @@ -660,6 +672,9 @@ async def shutdown(self) -> None: _LOGGER.debug("Ignoring duplicate shutdown event") return + self.global_updater.stop() + self._device_availability_checker.stop() + async def _cancel_tasks(tasks_to_cancel: Iterable) -> None: tasks = [t for t in tasks_to_cancel if not (t.done() or t.cancelled())] for task in tasks: @@ -696,4 +711,5 @@ def handle_message( # pylint: disable=unused-argument ) -> None: """Handle message from a device Event handler.""" if sender.ieee in self.devices and not self.devices[sender.ieee].available: + self.devices[sender.ieee].on_network = True self.async_update_device(sender, available=True) diff --git a/zha/application/helpers.py b/zha/application/helpers.py index bb6a5e387..91e1e5658 100644 --- a/zha/application/helpers.py +++ b/zha/application/helpers.py @@ -2,8 +2,10 @@ from __future__ import annotations +import asyncio import binascii import collections +from collections.abc import Callable import dataclasses from dataclasses import dataclass import enum @@ -25,7 +27,8 @@ CLUSTER_TYPE_OUT, CUSTOM_CONFIGURATION, ) -from zha.decorators import SetRegistry +from zha.async_ import gather_with_limited_concurrency +from zha.decorators import SetRegistry, callback, periodic # from zha.zigbee.cluster_handlers.registries import BINDABLE_CLUSTERS BINDABLE_CLUSTERS = SetRegistry() @@ -277,3 +280,116 @@ class ZHAData: default_factory=dict ) allow_polling: bool = dataclasses.field(default=False) + + +class GlobalUpdater: + """Global updater for ZHA. + + This class is used to update all listeners at a regular interval. The listeners + are `Callable` objects that are registered with the `register_update_listener` method. + """ + + _REFRESH_INTERVAL = (30, 45) + __polling_interval: int + + def __init__(self, gateway: Gateway): + """Initialize the GlobalUpdater.""" + self._updater_task_handle: asyncio.Task = None + self._update_listeners: list[Callable] = [] + self._gateway: Gateway = gateway + + def start(self): + """Start the global updater.""" + self._updater_task_handle = self._gateway.async_create_background_task( + self.update_listeners(), + name=f"global-updater_{self.__class__.__name__}", + eager_start=True, + untracked=True, + ) + _LOGGER.debug( + "started global updater with an interval of %s seconds", + getattr(self, "__polling_interval"), + ) + + def stop(self): + """Stop the global updater.""" + _LOGGER.debug("stopping global updater") + if self._updater_task_handle: + self._updater_task_handle.cancel() + self._updater_task_handle = None + _LOGGER.debug("global updater stopped") + + def register_update_listener(self, listener: Callable): + """Register an update listener.""" + self._update_listeners.append(listener) + + def remove_update_listener(self, listener: Callable): + """Remove an update listener.""" + self._update_listeners.remove(listener) + + @callback + @periodic(_REFRESH_INTERVAL) + async def update_listeners(self): + """Update all listeners.""" + _LOGGER.debug("Global updater interval starting") + if self._gateway.config.allow_polling: + for listener in self._update_listeners: + _LOGGER.debug("Global updater running update callback") + listener() + else: + _LOGGER.debug("Global updater interval skipped") + _LOGGER.debug("Global updater interval finished") + + +class DeviceAvailabilityChecker: + """Device availability checker for ZHA.""" + + _REFRESH_INTERVAL = (30, 45) + __polling_interval: int + + def __init__(self, gateway: Gateway): + """Initialize the DeviceAvailabilityChecker.""" + self._gateway: Gateway = gateway + self._device_availability_task_handle: asyncio.Task = None + + def start(self): + """Start the device availability checker.""" + self._device_availability_task_handle = ( + self._gateway.async_create_background_task( + self.check_device_availability(), + name=f"device-availability-checker_{self.__class__.__name__}", + eager_start=True, + untracked=True, + ) + ) + _LOGGER.debug( + "started device availability checker with an interval of %s seconds", + getattr(self, "__polling_interval"), + ) + + def stop(self): + """Stop the device availability checker.""" + _LOGGER.debug("stopping device availability checker") + if self._device_availability_task_handle: + self._device_availability_task_handle.cancel() + self._device_availability_task_handle = None + _LOGGER.debug("device availability checker stopped") + + @periodic(_REFRESH_INTERVAL) + async def check_device_availability(self): + """Check device availability.""" + _LOGGER.debug("Device availability checker interval starting") + if self._gateway.config.allow_polling: + _LOGGER.debug("Checking device availability") + # 20 because most devices will not make remote calls + await gather_with_limited_concurrency( + 20, + *( + dev._check_available() + for dev in self._gateway.devices.values() + if not dev.is_coordinator + ), + ) + _LOGGER.debug("Device availability checker interval finished") + else: + _LOGGER.debug("Device availability checker interval skipped") diff --git a/zha/application/platforms/sensor/__init__.py b/zha/application/platforms/sensor/__init__.py index 49eb33b7c..8550f5de2 100644 --- a/zha/application/platforms/sensor/__init__.py +++ b/zha/application/platforms/sensor/__init__.py @@ -319,9 +319,6 @@ class DeviceCounterSensor(BaseEntity): """Device counter sensor.""" PLATFORM = Platform.SENSOR - _REFRESH_INTERVAL = (30, 45) - __polling_interval: int - _use_custom_polling: bool = True _attr_state_class: SensorStateClass = SensorStateClass.TOTAL _attr_entity_category = EntityCategory.DIAGNOSTIC _attr_entity_registry_enabled_default = False @@ -364,18 +361,7 @@ def __init__( self._zigpy_counter_groups: str = counter_groups self._zigpy_counter_group: str = counter_group self._attr_name: str = self._zigpy_counter.name - self._tracked_tasks.append( - self._device.gateway.async_create_background_task( - self._refresh(), - name=f"sensor_state_poller_{self.unique_id}_{self.__class__.__name__}", - eager_start=True, - untracked=True, - ) - ) - self.debug( - "started polling with refresh interval of %s", - getattr(self, "__polling_interval"), - ) + self._device.gateway.global_updater.register_update_listener(self.update) # we double create these in discovery tests because we reissue the create calls to count and prove them out if self.unique_id not in self._device.platform_entities: self._device.platform_entities[self.unique_id] = self @@ -428,16 +414,11 @@ def device(self) -> Device: """Return the device.""" return self._device - async def async_update(self) -> None: - """Retrieve latest state.""" - self.maybe_emit_state_changed_event() - - @periodic(_REFRESH_INTERVAL) - async def _refresh(self): + def update(self): """Call async_update at a constrained random interval.""" if self._device.available and self._device.gateway.config.allow_polling: self.debug("polling for updated state") - await self.async_update() + self.maybe_emit_state_changed_event() else: self.debug( "skipping polling for updated state, available: %s, allow polled requests: %s", @@ -445,6 +426,11 @@ async def _refresh(self): self._device.gateway.config.allow_polling, ) + async def on_remove(self) -> None: + """Cancel tasks this entity owns.""" + self._device.gateway.global_updater.remove_update_listener(self.update) + await super().on_remove() + class EnumSensor(Sensor): """Sensor with value from enum.""" @@ -1313,7 +1299,6 @@ class RSSISensor(Sensor): _attr_native_unit_of_measurement: str | None = SIGNAL_STRENGTH_DECIBELS_MILLIWATT _attr_entity_category = EntityCategory.DIAGNOSTIC _attr_entity_registry_enabled_default = False - _attr_should_poll = True # BaseZhaEntity defaults to False _attr_translation_key: str = "rssi" @classmethod @@ -1334,6 +1319,18 @@ def create_platform_entity( return None return cls(unique_id, cluster_handlers, endpoint, device, **kwargs) + def __init__( + self, + unique_id: str, + cluster_handlers: list[ClusterHandler], + endpoint: Endpoint, + device: Device, + **kwargs: Any, + ) -> None: + """Init.""" + super().__init__(unique_id, cluster_handlers, endpoint, device, **kwargs) + self.device.gateway.global_updater.register_update_listener(self.update) + @property def state(self) -> dict: """Return the state of the sensor.""" @@ -1346,6 +1343,23 @@ def native_value(self) -> str | int | float | None: """Return the state of the entity.""" return getattr(self._device.device, self._unique_id_suffix) + def update(self): + """Call async_update at a constrained random interval.""" + if self._device.available and self._device.gateway.config.allow_polling: + self.debug("polling for updated state") + self.maybe_emit_state_changed_event() + else: + self.debug( + "skipping polling for updated state, available: %s, allow polled requests: %s", + self._device.available, + self._device.gateway.config.allow_polling, + ) + + async def on_remove(self) -> None: + """Cancel tasks this entity owns.""" + self._device.gateway.global_updater.remove_update_listener(self.update) + await super().on_remove() + @MULTI_MATCH(cluster_handler_names=CLUSTER_HANDLER_BASIC) class LQISensor(RSSISensor): diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index 93bac8547..778e6d6e4 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -67,7 +67,6 @@ ) from zha.application.helpers import async_get_zha_config_value, convert_to_zcl_values from zha.application.platforms import PlatformEntity, PlatformEntityInfo -from zha.decorators import periodic from zha.event import EventBase from zha.exceptions import ZHAException from zha.mixins import LogMixin @@ -200,7 +199,6 @@ class ExtendedDeviceInfo(DeviceInfo): class Device(LogMixin, EventBase): """ZHA Zigbee device object.""" - __polling_interval: int _ha_device_id: str def __init__( @@ -244,6 +242,7 @@ def __init__( and time.time() - self.last_seen < self.consider_unavailable_time ) self._checkins_missed_count: int = 0 + self._on_network: bool = False self._platform_entities: dict[str, PlatformEntity] = {} self.semaphore: asyncio.Semaphore = asyncio.Semaphore(3) @@ -256,20 +255,6 @@ def __init__( if ep_id != 0: self._endpoints[ep_id] = Endpoint.new(endpoint, self) - if not self.is_coordinator: - self._tracked_tasks.append( - self.gateway.async_create_background_task( - self._check_available(), - name=f"device_check_alive_{self.ieee}", - eager_start=True, - untracked=True, - ) - ) - self.debug( - "starting availability checks - interval: %s", - getattr(self, "__polling_interval"), - ) - @cached_property def device(self) -> zigpy.device.Device: """Return underlying Zigpy device.""" @@ -416,13 +401,24 @@ def device_automation_triggers(self) -> dict[tuple[str, str], dict[str, str]]: @property def available(self): """Return True if device is available.""" - return self._available + return self._available and self._on_network @available.setter def available(self, new_availability: bool) -> None: """Set device availability.""" self._available = new_availability + @property + def on_network(self): + """Return True if device is currently on the network.""" + return self._on_network + + @on_network.setter + def on_network(self, new_on_network: bool) -> None: + """Set device on_network flag.""" + self._on_network = new_on_network + self.update_available(new_on_network) + @property def power_configuration_ch(self) -> ClusterHandler | None: """Return power configuration cluster handler.""" @@ -523,11 +519,14 @@ def async_update_sw_build_id(self, sw_version: int) -> None: """Update device sw version.""" self._sw_build_id = sw_version - @periodic(_UPDATE_ALIVE_INTERVAL) async def _check_available(self, *_: Any) -> None: # don't flip the availability state of the coordinator if self.is_coordinator: return + if not self._on_network: + self.debug("Device is not on the network, marking unavailable") + self.update_available(False) + return if self.last_seen is None: self.debug("last_seen is None, marking the device unavailable") self.update_available(False) @@ -592,8 +591,10 @@ def update_available(self, available: bool) -> None: "Device availability changed and device became available," " reinitializing cluster handlers" ) - self._gateway.track_task( - asyncio.create_task(self._async_became_available()) + self._gateway.async_create_task( + self._async_became_available(), + name=f"({self.nwk},{self.model})_async_became_available", + eager_start=True, ) return if availability_changed and not available: