Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Let cluster config/init bubble up exception groups #120

Open
wants to merge 8 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from zigpy.quirks import get_device
import zigpy.types
from zigpy.zcl.clusters.general import Basic, Groups
from zigpy.zcl.foundation import Status
from zigpy.zcl.foundation import GENERAL_COMMANDS, GeneralCommand, Status
import zigpy.zdo.types as zdo_t

from tests import common
Expand Down Expand Up @@ -450,7 +450,6 @@ def _mock_dev(
endpoint = device.add_endpoint(epid)
endpoint.device_type = ep[SIG_EP_TYPE]
endpoint.profile_id = ep.get(SIG_EP_PROFILE)
endpoint.request = AsyncMock(return_value=[0])

for cluster_id in ep.get(SIG_EP_INPUT, []):
endpoint.add_input_cluster(cluster_id)
Expand All @@ -463,9 +462,38 @@ def _mock_dev(
else:
device = get_device(device)

async def mock_request(
cluster: zigpy.types.ClusterId,
sequence: zigpy.types.uint8_t,
data: bytes,
expect_reply: bool = True,
command_id: GeneralCommand | zigpy.types.uint8_t = 0x00,
):
# if isinstance(command_id, (int, GeneralCommand)):
# Some commands can't handle default response, and will
# fail with a non list element as first element.
if command_id in (
GeneralCommand.Read_Reporting_Configuration,
GeneralCommand.Write_Attributes,
):
return [[]]

if command_id in GeneralCommand.__members__:
return GENERAL_COMMANDS[GeneralCommand.Default_Response].schema(
command_id=command_id, status=Status.UNSUP_GENERAL_COMMAND
)

return [0]

# add request mock after device creation since quirks may have added endpoints
for epid, endpoint in device.endpoints.items():
if epid:
endpoint.request = AsyncMock(side_effect=mock_request)
else:
endpoint.request = AsyncMock(return_value=[0])

if patch_cluster:
for endpoint in (ep for epid, ep in device.endpoints.items() if epid):
endpoint.request = AsyncMock(return_value=[0])
for cluster in itertools.chain(
endpoint.in_clusters.values(), endpoint.out_clusters.values()
):
Expand Down
9 changes: 4 additions & 5 deletions tests/test_cluster_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,8 +837,10 @@ async def test_ep_cluster_handlers_configure(cluster_handler) -> None:
mock.patch.dict(endpoint.claimed_cluster_handlers, claimed, clear=True),
mock.patch.dict(endpoint.client_cluster_handlers, client_handlers, clear=True),
):
await endpoint.async_configure()
await endpoint.async_initialize(mock.sentinel.from_cache)
with pytest.raises(ExceptionGroup):
await endpoint.async_configure()
with pytest.raises(ExceptionGroup):
await endpoint.async_initialize(mock.sentinel.from_cache)

for ch in [*claimed.values(), *client_handlers.values()]:
assert ch.async_initialize.call_count == 1
Expand All @@ -847,9 +849,6 @@ async def test_ep_cluster_handlers_configure(cluster_handler) -> None:
assert ch.async_configure.call_count == 1
assert ch.async_configure.await_count == 1

assert ch_3.debug.call_count == 2
assert ch_5.debug.call_count == 2


async def test_poll_control_configure(
poll_control_ch: PollControlClusterHandler,
Expand Down
8 changes: 0 additions & 8 deletions tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,6 @@ async def test_async_add_to_group_remove_from_group(
async def test_async_bind_to_group(
device_joined: Callable[[ZigpyDevice], Awaitable[Device]],
zigpy_device: Callable[..., ZigpyDevice], # pylint: disable=redefined-outer-name
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test async_bind_to_group method."""
zigpy_dev = zigpy_device(with_basic_cluster_handler=True)
Expand All @@ -640,19 +639,12 @@ async def test_async_bind_to_group(
group.group_id,
[ClusterBinding(name="on_off", type=CLUSTER_TYPE_OUT, id=6, endpoint_id=3)],
)
assert (
"0xb79c: Bind_req 00:0d:7f:00:0a:90:69:e8, ep: 3, cluster: 6 to group: 0x1001 completed: [<Status.SUCCESS: 0>]"
in caplog.text
)

await zha_device_remote.async_unbind_from_group(
group.group_id,
[ClusterBinding(name="on_off", type=CLUSTER_TYPE_OUT, id=6, endpoint_id=3)],
)

m1 = "0xb79c: Unbind_req 00:0d:7f:00:0a:90:69:e8, ep: 3, cluster: 6"
assert f"{m1} to group: 0x1001 completed: [<Status.SUCCESS: 0>]" in caplog.text


async def test_device_automation_triggers(
device_joined: Callable[[ZigpyDevice], Awaitable[Device]],
Expand Down
7 changes: 6 additions & 1 deletion zha/application/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
import zigpy.device
import zigpy.endpoint
from zigpy.exceptions import DeliveryError
import zigpy.group
from zigpy.state import State
from zigpy.types.named import EUI64
Expand Down Expand Up @@ -618,7 +619,11 @@
)
# we don't have to do this on a nwk swap
# but we don't have a way to tell currently
await zha_device.async_configure()
try:
await zha_device.async_configure()
except* (TimeoutError, DeliveryError):
zha_device.debug("ignoring error %s during rejoin", exc_info=True)

Check warning on line 625 in zha/application/gateway.py

View check run for this annotation

Codecov / codecov/patch

zha/application/gateway.py#L625

Added line #L625 was not covered by tests

device_info = ExtendedDeviceInfoWithPairingStatus(
pairing_status=DevicePairingStatus.CONFIGURED,
**zha_device.extended_device_info.__dict__,
Expand Down
2 changes: 1 addition & 1 deletion zha/zigbee/cluster_handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def _configure_reporting_status(
event_data: dict[str, dict[str, Any]],
) -> None:
"""Parse configure reporting result."""
if isinstance(res, (Exception, ConfigureReportingResponseRecord)):
if isinstance(res, (Exception, ConfigureReportingResponseRecord, int)):
# assume default response
self.debug(
"attr reporting for '%s' on '%s': %s",
Expand Down
11 changes: 6 additions & 5 deletions zha/zigbee/cluster_handlers/lightlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import zigpy.exceptions
from zigpy.zcl.clusters.lightlink import LightLink
from zigpy.zcl.foundation import GENERAL_COMMANDS, GeneralCommand

from zha.zigbee.cluster_handlers import ClusterHandler, ClusterHandlerStatus, registries

Expand Down Expand Up @@ -34,14 +33,16 @@ async def async_configure(self) -> None:
self.warning("Couldn't get list of groups: %s", str(exc))
return

if isinstance(rsp, GENERAL_COMMANDS[GeneralCommand.Default_Response].schema):
groups = []
else:
if isinstance(
rsp, LightLink.ClientCommandDefs.get_group_identifiers_rsp.schema
):
groups = rsp.group_info_records
else:
groups = []

if groups:
for group in groups:
self.debug("Adding coordinator to 0x%04x group id", group.group_id)
await coordinator.add_to_group(group.group_id)
else:
await coordinator.add_to_group(0x0000, name="Default Lightlink Group")
await coordinator.add_to_group(0x0000, name="Lightlink Group")
22 changes: 13 additions & 9 deletions zha/zigbee/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,16 @@ async def _execute_handler_tasks(
*self.claimed_cluster_handlers.values(),
*self.client_cluster_handlers.values(),
]
tasks = [getattr(ch, func_name)(*args) for ch in cluster_handlers]

faults: list[Exception] = []

async def caller(ch: ClusterHandler):
try:
await getattr(ch, func_name)(*args)
except Exception as outcome:
faults.append(outcome)

tasks = [caller(ch) for ch in cluster_handlers]

gather: Callable[..., Awaitable]

Expand All @@ -201,14 +210,9 @@ async def _execute_handler_tasks(
else:
gather = functools.partial(gather_with_limited_concurrency, max_concurrency)

results = await gather(*tasks, return_exceptions=True)
for cluster_handler, outcome in zip(cluster_handlers, results):
if isinstance(outcome, Exception):
cluster_handler.debug(
"'%s' stage failed: %s", func_name, str(outcome), exc_info=outcome
)
else:
cluster_handler.debug("'%s' stage succeeded", func_name)
await gather(*tasks)
if faults:
raise ExceptionGroup(f"{func_name}: some clusters failed", faults)

def async_new_entity(
self,
Expand Down
Loading