Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanly shut down the serial port on disconnect #259

Merged
merged 11 commits into from
Oct 27, 2024
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ license = {text = "GPL-3.0"}
requires-python = ">=3.8"
dependencies = [
"voluptuous",
"zigpy>=0.68.0",
"zigpy>=0.70.0",
'async-timeout; python_version<"3.11"',
]

Expand Down Expand Up @@ -47,6 +47,7 @@ ignore_errors = true

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"

[tool.flake8]
exclude = [".venv", ".git", ".tox", "docs", "venv", "bin", "lib", "deps", "build"]
Expand Down
44 changes: 34 additions & 10 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,23 @@


@pytest.fixture
def gateway():
async def gateway():
return uart.Gateway(api=None)


@pytest.fixture
def api(gateway, mock_command_rsp):
async def api(gateway, mock_command_rsp):
loop = asyncio.get_running_loop()

async def mock_connect(config, api):
transport = MagicMock()
transport.close = MagicMock(
side_effect=lambda: loop.call_soon(gateway.connection_lost, None)
)

gateway._api = api
gateway.connection_made(MagicMock())
gateway.connection_made(transport)

return gateway

with patch("zigpy_deconz.uart.connect", side_effect=mock_connect):
Expand Down Expand Up @@ -178,15 +186,33 @@ async def test_connect(api, mock_command_rsp):
await api.connect()


async def test_connect_failure(api, mock_command_rsp):
transport = None

def mock_version(*args, **kwargs):
nonlocal transport
transport = api._uart._transport

raise asyncio.TimeoutError()

with patch.object(api, "version", side_effect=mock_version):
# We connect but fail to probe
with pytest.raises(asyncio.TimeoutError):
await api.connect()

assert api._uart is None
assert len(transport.close.mock_calls) == 1


async def test_close(api):
await api.connect()

uart = api._uart
uart.close = MagicMock(wraps=uart.close)
uart.disconnect = AsyncMock()

api.close()
await api.disconnect()
assert api._uart is None
assert uart.close.call_count == 1
assert uart.disconnect.call_count == 1


def test_commands():
Expand Down Expand Up @@ -898,11 +924,9 @@ async def test_data_poller(api, mock_command_rsp):

# The task is cancelled on close
task = api._data_poller_task
api.close()
await api.disconnect()
assert api._data_poller_task is None

if sys.version_info >= (3, 11):
assert task.cancelling()
assert task.done()


async def test_get_device_state(api, mock_command_rsp):
Expand Down
7 changes: 4 additions & 3 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ async def test_connect_failure(app):
with patch.object(application, "Deconz") as api_mock:
api = api_mock.return_value = MagicMock()
api.connect = AsyncMock(side_effect=RuntimeError("Broken"))
api.disconnect = AsyncMock()

app._api = None

Expand All @@ -195,16 +196,16 @@ async def test_connect_failure(app):

assert app._api is None
api.connect.assert_called_once()
api.close.assert_called_once()
api.disconnect.assert_called_once()


async def test_disconnect(app):
api_close = app._api.close = MagicMock()
api_disconnect = app._api.disconnect = AsyncMock()

await app.disconnect()

assert app._api is None
assert api_close.call_count == 1
assert api_disconnect.call_count == 1


async def test_disconnect_no_api(app):
Expand Down
13 changes: 11 additions & 2 deletions tests/test_uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from unittest import mock

import pytest
from zigpy.config import CONF_DEVICE_BAUDRATE, CONF_DEVICE_PATH
from zigpy.config import (
CONF_DEVICE_BAUDRATE,
CONF_DEVICE_FLOW_CONTROL,
CONF_DEVICE_PATH,
)
import zigpy.serial

from zigpy_deconz import uart
Expand All @@ -28,7 +32,12 @@ async def mock_conn(loop, protocol_factory, **kwargs):
monkeypatch.setattr(zigpy.serial, "create_serial_connection", mock_conn)

await uart.connect(
{CONF_DEVICE_PATH: "/dev/null", CONF_DEVICE_BAUDRATE: 115200}, api
{
CONF_DEVICE_PATH: "/dev/null",
CONF_DEVICE_BAUDRATE: 115200,
CONF_DEVICE_FLOW_CONTROL: None,
},
api,
)


Expand Down
27 changes: 13 additions & 14 deletions zigpy_deconz/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
else:
from asyncio import timeout as asyncio_timeout # pragma: no cover

from zigpy.config import CONF_DEVICE_PATH
from zigpy.datastructures import PriorityLock
from zigpy.types import (
APSStatus,
Expand Down Expand Up @@ -461,37 +460,37 @@ def protocol_version(self) -> int:

async def connect(self) -> None:
assert self._uart is None

self._uart = await zigpy_deconz.uart.connect(self._config, self)

await self.version()
try:
await self.version()
device_state_rsp = await self.send_command(CommandId.device_state)
except Exception:
await self.disconnect()
self._uart = None
raise

device_state_rsp = await self.send_command(CommandId.device_state)
self._device_state = device_state_rsp["device_state"]

self._data_poller_task = asyncio.create_task(self._data_poller())

def connection_lost(self, exc: Exception) -> None:
def connection_lost(self, exc: Exception | None) -> None:
"""Lost serial connection."""
LOGGER.debug(
"Serial %r connection lost unexpectedly: %r",
self._config[CONF_DEVICE_PATH],
exc,
)

if self._app is not None:
self._app.connection_lost(exc)

def close(self):
self._app = None

async def disconnect(self):
if self._data_poller_task is not None:
self._data_poller_task.cancel()
self._data_poller_task = None

if self._uart is not None:
self._uart.close()
await self._uart.disconnect()
self._uart = None

self._app = None

def _get_command_priority(self, command: Command) -> int:
return {
# The watchdog is fed using `write_parameter` and `get_device_state` so they
Expand Down
55 changes: 21 additions & 34 deletions zigpy_deconz/uart.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,50 @@
"""Uart module."""

from __future__ import annotations

import asyncio
import binascii
import logging
from typing import Callable, Dict
from typing import Any, Callable

import zigpy.config
import zigpy.serial

LOGGER = logging.getLogger(__name__)


class Gateway(asyncio.Protocol):
class Gateway(zigpy.serial.SerialProtocol):
END = b"\xC0"
ESC = b"\xDB"
ESC_END = b"\xDC"
ESC_ESC = b"\xDD"

def __init__(self, api, connected_future=None):
def __init__(self, api):
"""Initialize instance of the UART gateway."""

super().__init__()
self._api = api
self._buffer = b""
self._connected_future = connected_future
self._transport = None

def connection_lost(self, exc) -> None:
def connection_lost(self, exc: Exception | None) -> None:
"""Port was closed expectedly or unexpectedly."""
super().connection_lost(exc)

if exc is not None:
LOGGER.warning("Lost connection: %r", exc, exc_info=exc)

self._api.connection_lost(exc)

def connection_made(self, transport):
"""Call this when the uart connection is established."""

LOGGER.debug("Connection made")
self._transport = transport
if self._connected_future and not self._connected_future.done():
self._connected_future.set_result(True)
if self._api is not None:
self._api.connection_lost(exc)

def close(self):
self._transport.close()
super().close()
self._api = None

def send(self, data):
def send(self, data: bytes) -> None:
"""Send data, taking care of escaping and framing."""
LOGGER.debug("Send: %s", binascii.hexlify(data).decode())
checksum = bytes(self._checksum(data))
frame = self._escape(data + checksum)
self._transport.write(self.END + frame + self.END)

def data_received(self, data):
def data_received(self, data: bytes) -> None:
"""Handle data received from the uart."""
self._buffer += data
super().data_received(data)

while self._buffer:
end = self._buffer.find(self.END)
if end < 0:
Expand Down Expand Up @@ -121,23 +112,19 @@ def _checksum(self, data):
return bytes(ret)


async def connect(config: Dict[str, any], api: Callable) -> Gateway:
loop = asyncio.get_running_loop()
connected_future = loop.create_future()
protocol = Gateway(api, connected_future)
async def connect(config: dict[str, Any], api: Callable) -> Gateway:
protocol = Gateway(api)

LOGGER.debug("Connecting to %s", config[zigpy.config.CONF_DEVICE_PATH])

_, protocol = await zigpy.serial.create_serial_connection(
loop=loop,
loop=asyncio.get_running_loop(),
protocol_factory=lambda: protocol,
url=config[zigpy.config.CONF_DEVICE_PATH],
baudrate=config[zigpy.config.CONF_DEVICE_BAUDRATE],
xonxoff=False,
flow_control=config[zigpy.config.CONF_DEVICE_FLOW_CONTROL],
)

await connected_future

LOGGER.debug("Connected to %s", config[zigpy.config.CONF_DEVICE_PATH])
await protocol.wait_until_connected()

return protocol
4 changes: 2 additions & 2 deletions zigpy_deconz/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def connect(self):
try:
await api.connect()
except Exception:
api.close()
await api.disconnect()
raise

self._api = api
Expand All @@ -109,7 +109,7 @@ async def disconnect(self):
self._delayed_neighbor_scan_task = None

if self._api is not None:
self._api.close()
await self._api.disconnect()
self._api = None

async def permit_with_link_key(self, node: t.EUI64, link_key: t.KeyData, time_s=60):
Expand Down