Skip to content

Commit

Permalink
Convert ZCL attributes from strings when writing (#267)
Browse files Browse the repository at this point in the history
* Convert ZCL attributes when writing

* Handle bytes too

* Add some unit tests

* Address review comments
  • Loading branch information
puddly authored Nov 27, 2024
1 parent c494a3e commit 6737b36
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 22 deletions.
86 changes: 84 additions & 2 deletions tests/test_application_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Test zha application helpers."""

from typing import Any

import pytest
from zigpy.device import Device as ZigpyDevice
from zigpy.profiles import zha
from zigpy.zcl.clusters.general import Basic, OnOff
import zigpy.types as t
from zigpy.zcl.clusters.general import Basic, Identify, OnOff
from zigpy.zcl.clusters.security import IasZone

from tests.common import (
Expand All @@ -14,7 +18,12 @@
join_zigpy_device,
)
from zha.application.gateway import Gateway
from zha.application.helpers import async_is_bindable_target, get_matched_clusters
from zha.application.helpers import (
async_is_bindable_target,
convert_to_zcl_values,
convert_zcl_value,
get_matched_clusters,
)

IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"
Expand Down Expand Up @@ -105,3 +114,76 @@ async def test_get_matched_clusters(
assert matches[0].target_ep_id == 1

assert not await get_matched_clusters(not_bindable_zha_device, remote_zha_device)


class SomeEnum(t.enum8):
"""Some enum."""

value_1 = 0x12
value_2 = 0x34
value_3 = 0x56


class SomeFlag(t.bitmap8):
"""Some bitmap."""

flag_1 = 0b00000001
flag_2 = 0b00000010
flag_3 = 0b00000100


@pytest.mark.parametrize(
("text", "field_type", "result"),
[
# Bytes
(
"b'Some data\\x00\\x01'",
t.SerializableBytes,
t.SerializableBytes(b"Some data\x00\x01"),
),
(
'b"Some data\\x00\\x01"',
t.SerializableBytes,
t.SerializableBytes(b"Some data\x00\x01"),
),
(
b"Some data\x00\x01".hex(),
t.SerializableBytes,
t.SerializableBytes(b"Some data\x00\x01"),
),
# Enum
("value 1", SomeEnum, SomeEnum.value_1),
("value_1", SomeEnum, SomeEnum.value_1),
("SomeEnum.value_1", SomeEnum, SomeEnum.value_1),
(0x12, SomeEnum, SomeEnum.value_1),
# Flag
("flag 1", SomeFlag, SomeFlag.flag_1),
("flag_1", SomeFlag, SomeFlag.flag_1),
("SomeFlag.flag_1", SomeFlag, SomeFlag.flag_1),
("SomeFlag.flag_1|flag_2", SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2),
(0b00000001, SomeFlag, SomeFlag.flag_1),
([0b00000001], SomeFlag, SomeFlag.flag_1),
([0b00000001, 0b00000010], SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2),
(["flag_1", "flag_2"], SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2),
# Int
(0x1234, t.uint16_t, 0x1234),
("0x1234", t.uint16_t, 0x1234),
("4660", t.uint16_t, 0x1234),
# Some fallthrough type
(1.000, t.Single, t.Single(1.000)),
("1.000", t.Single, t.Single(1.000)),
],
)
def test_convert_zcl_value(text: Any, field_type: Any, result: Any) -> None:
"""Test converting ZCL values."""
assert convert_zcl_value(text, field_type) == result


def test_convert_to_zcl_values() -> None:
"""Test converting ZCL values."""

identify_schema = Identify.ServerCommandDefs.identify.schema
assert convert_to_zcl_values(
fields={"identify_time": "1"},
schema=identify_schema,
) == {"identify_time": 1}
72 changes: 53 additions & 19 deletions zha/application/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

import ast
import asyncio
import binascii
import collections
from collections.abc import Callable
import contextlib
import dataclasses
from dataclasses import dataclass
import datetime
Expand Down Expand Up @@ -126,6 +128,53 @@ async def get_matched_clusters(
return clusters_to_bind


def convert_zcl_value(value: Any, field_type: Any) -> Any:
"""Convert user input to ZCL value."""
if issubclass(field_type, enum.Flag):
if isinstance(value, str):
with contextlib.suppress(ValueError):
value = int(value)

if isinstance(value, int):
value = field_type(value)
elif isinstance(value, str):
# List of flags: `SomeFlag.field1 | field2`
value = [v.strip() for v in value.split(".", 1)[-1].split("|")]

if isinstance(value, list):
new_value = 0

for flag in value:
if isinstance(flag, str):
new_value |= field_type[flag.replace(" ", "_")]
else:
new_value |= flag

value = field_type(new_value)
elif issubclass(field_type, enum.Enum):
value = (
field_type[value.replace(" ", "_").split(".", 1)[-1]]
if isinstance(value, str)
else field_type(value)
)
elif issubclass(field_type, zigpy.types.SerializableBytes):
if value.startswith(("b'", 'b"')):
value = ast.literal_eval(value)
else:
value = bytes.fromhex(value)

value = field_type(value)
elif issubclass(field_type, int):
if isinstance(value, str) and value.startswith("0x"):
value = int(value, 16)

value = field_type(value)
else:
value = field_type(value)

return value


def convert_to_zcl_values(
fields: dict[str, Any], schema: CommandSchema
) -> dict[str, Any]:
Expand All @@ -134,32 +183,17 @@ def convert_to_zcl_values(
for field in schema.fields:
if field.name not in fields:
continue
value = fields[field.name]
if issubclass(field.type, enum.Flag) and isinstance(value, list):
new_value = 0

for flag in value:
if isinstance(flag, str):
new_value |= field.type[flag.replace(" ", "_")]
else:
new_value |= flag
value = fields[field.name]
new_value = converted_fields[field.name] = convert_zcl_value(value, field.type)

value = field.type(new_value)
elif issubclass(field.type, enum.Enum):
value = (
field.type[value.replace(" ", "_")]
if isinstance(value, str)
else field.type(value)
)
else:
value = field.type(value)
_LOGGER.debug(
"Converted ZCL schema field(%s) value from: %s to: %s",
field.name,
fields[field.name],
value,
new_value,
)
converted_fields[field.name] = value

return converted_fields


Expand Down
5 changes: 4 additions & 1 deletion zha/zigbee/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
ZHA_CLUSTER_HANDLER_MSG,
ZHA_EVENT,
)
from zha.application.helpers import convert_to_zcl_values
from zha.application.helpers import convert_to_zcl_values, convert_zcl_value
from zha.application.platforms import BaseEntityInfo, PlatformEntity
from zha.event import EventBase
from zha.exceptions import ZHAException
Expand Down Expand Up @@ -874,6 +874,9 @@ async def write_zigbee_attribute(
f" writing attribute {attribute} with value {value}"
) from exc

attr_def = cluster.find_attribute(attribute)
value = convert_zcl_value(value, attr_def.type)

try:
response = await cluster.write_attributes(
{attribute: value}, manufacturer=manufacturer
Expand Down

0 comments on commit 6737b36

Please sign in to comment.