Skip to content

Commit

Permalink
Merge pull request #15 from bdraco/use_system_endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Jan 18, 2022
2 parents 7da4e33 + 7d38cd4 commit 9e1d951
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 30 deletions.
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pytest-cov = "^3.0.0"
pyupgrade = "^2.29.1"
tox = "^3.20.1"
pytest-asyncio = "^0.17.2"
aioresponses = "^0.7.3"

[tool.semantic_release]
branch = "main"
Expand Down
101 changes: 98 additions & 3 deletions tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,25 @@
from unittest.mock import MagicMock, patch

import pytest
from aioresponses import aioresponses

from unifi_discovery import (
DISCOVERY_PORT,
UBNT_REQUEST_PAYLOAD,
AIOUnifiScanner,
UnifiDevice,
UnifiDiscovery,
UnifiService,
create_udp_socket,
)


@pytest.fixture
def mock_aioresponse():
with aioresponses() as m:
yield m


@pytest.fixture
async def mock_discovery_aio_protocol():
"""Fixture to mock an asyncio connection."""
Expand All @@ -39,10 +47,11 @@ async def _mock_create_datagram_endpoint(func, sock=None):


@pytest.mark.asyncio
async def test_async_scanner_specific_address(mock_discovery_aio_protocol):
async def test_async_scanner_specific_address(
mock_discovery_aio_protocol, mock_aioresponse
):
"""Test scanner with a specific address."""
scanner = AIOUnifiScanner()

task = asyncio.ensure_future(
scanner.async_scan(timeout=10, address="192.168.212.1")
)
Expand Down Expand Up @@ -70,9 +79,87 @@ async def test_async_scanner_specific_address(mock_discovery_aio_protocol):


@pytest.mark.asyncio
async def test_async_scanner_broadcast(mock_discovery_aio_protocol):
async def test_async_scanner_broadcast(mock_discovery_aio_protocol, mock_aioresponse):
"""Test scanner with a broadcast."""
scanner = AIOUnifiScanner()
mock_aioresponse.get("https://192.168.203.1/proxy/protect/api", status=401)
mock_aioresponse.get(
"https://192.168.203.1/api/system",
payload={
"hardware": {"shortname": "UDMPROSE"},
"name": "UDM Pro SE",
"mac": "245A4CDD6616",
"isSingleUser": True,
"isSsoEnabled": True,
"directConnectDomain": "xyz.id.ui.direct",
},
)

task = asyncio.ensure_future(scanner.async_scan(timeout=0.01))
_, protocol = await mock_discovery_aio_protocol()
protocol.datagram_received(
UBNT_REQUEST_PAYLOAD,
("192.168.203.1", DISCOVERY_PORT),
)
protocol.datagram_received(
b"",
("127.0.0.1", DISCOVERY_PORT),
)
protocol.datagram_received(
None,
("127.0.0.1", DISCOVERY_PORT),
)
protocol.datagram_received(
b"\x01\x00\x00\xa5\x01\x00\x06$ZLu\xba\xe6\x02\x00\n$ZLu\xba\xe6\xc0\xa8\xd5/\x03\x001UFP-UAP-B.MT7622_SOC.v0.4.0.4.340d302.220106.0349\x04\x00\x04\xc0\xa8\xd5/\x05\x00\x06$ZLu\xba\xe6\n\x00\x04\x00\x0c\xda/\x0b\x00\x11AlexanderTechRoom\x0c\x00\tUFP-UAP-B\x10\x00\x02\xa6 \x14\x00\x18Unifi-Protect-UAP-Bridge\x17\x00\x01\x00",
("192.168.213.252", DISCOVERY_PORT),
)
await task
assert scanner.found_devices == [
UnifiDevice(
source_ip="192.168.203.1",
hw_addr="24:5a:4c:dd:66:16",
ip_info=None,
addr_entry=None,
fw_version=None,
mac_address=None,
uptime=None,
hostname="UDM-Pro-SE",
platform="UDMPROSE",
model=None,
signature_version="1",
services={UnifiService.Protect: True},
direct_connect_domain="xyz.id.ui.direct",
is_sso_enabled=True,
is_single_user=True,
),
UnifiDevice(
source_ip="192.168.213.252",
hw_addr="24:5a:4c:75:ba:e6",
ip_info=["24:5a:4c:75:ba:e6;192.168.213.47"],
addr_entry="192.168.213.47",
fw_version="UFP-UAP-B.MT7622_SOC.v0.4.0.4.340d302.220106.0349",
mac_address="24:5a:4c:75:ba:e6",
uptime=842287,
hostname="AlexanderTechRoom",
platform="UFP-UAP-B",
model="Unifi-Protect-UAP-Bridge",
signature_version="1",
services={UnifiService.Protect: False},
direct_connect_domain=None,
is_sso_enabled=None,
is_single_user=None,
),
]


@pytest.mark.asyncio
async def test_async_scanner_no_system_response(
mock_discovery_aio_protocol, mock_aioresponse
):
"""Test scanner with a broadcast when the system api does not response."""
scanner = AIOUnifiScanner()
mock_aioresponse.get("https://192.168.203.1/proxy/protect/api", status=401)
mock_aioresponse.get("https://192.168.203.1/api/system", status=404)

task = asyncio.ensure_future(scanner.async_scan(timeout=0.01))
_, protocol = await mock_discovery_aio_protocol()
Expand Down Expand Up @@ -106,6 +193,10 @@ async def test_async_scanner_broadcast(mock_discovery_aio_protocol):
platform=None,
model=None,
signature_version="1",
services={UnifiService.Protect: True},
direct_connect_domain=None,
is_sso_enabled=None,
is_single_user=None,
),
UnifiDevice(
source_ip="192.168.213.252",
Expand All @@ -119,6 +210,10 @@ async def test_async_scanner_broadcast(mock_discovery_aio_protocol):
platform="UFP-UAP-B",
model="Unifi-Protect-UAP-Bridge",
signature_version="1",
services={UnifiService.Protect: False},
direct_connect_domain=None,
is_sso_enabled=None,
is_single_user=None,
),
]

Expand Down
94 changes: 68 additions & 26 deletions unifi_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
import socket
import time
from contextlib import suppress
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from enum import Enum, auto
from http import HTTPStatus
from ipaddress import ip_address, ip_network
from struct import unpack
from typing import TYPE_CHECKING, Awaitable, Callable, cast

from aiohttp import ClientResponse, ClientSession, ClientTimeout, TCPConnector
from aiohttp import (
ClientError,
ClientResponse,
ClientSession,
ClientTimeout,
TCPConnector,
)

if TYPE_CHECKING:
from pyroute2 import IPRoute # type: ignore
Expand All @@ -35,7 +41,7 @@ class UnifiService(Enum):
)


PROBE_PLATFORMS = {"UDMPROSE", "UDMPRO", "UNVR", "UNVR", None}
PROBE_PLATFORMS = {"UDMPROSE", "UDMPRO", "UNVR", "UNVRPRO", "UCKP", None}

# UBNT discovery packet payload and reply signature
UBNT_REQUEST_PAYLOAD = b"\x01\x00\x00\x00"
Expand All @@ -56,6 +62,10 @@ def mac_repr(data):
return ":".join(("%02x" % b) for b in data)


def _format_mac(mac: str) -> str:
return ":".join(mac.lower()[i : i + 2] for i in range(0, 12, 2))


def ip_repr(data):
return ".".join(("%d" % b) for b in data)

Expand Down Expand Up @@ -116,6 +126,9 @@ class UnifiDevice:
model: str | None = None
signature_version: str | None = None
services: dict[UnifiService, bool] = field(default_factory=_services_dict)
direct_connect_domain: str | None = None
is_sso_enabled: bool | None = None
is_single_user: bool | None = None


def async_get_source_ip(target_ip: str) -> str | None:
Expand Down Expand Up @@ -363,39 +376,68 @@ async def _add_missing_hw_addresses(
self, response_list: dict[str, UnifiDevice]
) -> None:
"""Add any missing hardware addresses to the response list."""
if any(device.hw_addr is None for device in response_list.values()):
arp = ArpSearch()
neighbors = await arp.async_get_neighbors()
for source, device in response_list.items():
if device.hw_addr is None and device.source_ip in neighbors:
response_list[source] = UnifiDevice(
source_ip=device.source_ip,
hw_addr=neighbors[device.source_ip],
ip_info=[f"{neighbors[device.source_ip]};{device.source_ip}"],
services=device.services,
)
if not any(device.hw_addr is None for device in response_list.values()):
return
arp = ArpSearch()
neighbors = await arp.async_get_neighbors()
for source, device in response_list.items():
if device.hw_addr is None and device.source_ip in neighbors:
response_list[source] = replace(
device,
hw_addr=neighbors[device.source_ip],
ip_info=[f"{neighbors[device.source_ip]};{device.source_ip}"],
)

async def _probe_services(self, response_list: dict[str, UnifiDevice]) -> None:
async def _probe_services_and_system(
self, response_list: dict[str, UnifiDevice]
) -> None:
"""Check which services are available and update the services dict."""
timeout = ClientTimeout(total=5.0)
async with ClientSession(
connector=TCPConnector(verify_ssl=False), timeout=timeout
connector=TCPConnector(ssl=False), timeout=timeout
) as s:
device_tasks: dict[str, Awaitable] = {
device.source_ip: s.get(f"https://{device.source_ip}/proxy/protect/api")
for device in response_list.values()
if device.platform in PROBE_PLATFORMS
}
device_tasks: dict[str, Awaitable] = {}
system_tasks: dict[str, Awaitable] = {}
for device in response_list.values():
if device.platform in PROBE_PLATFORMS:
source_ip = device.source_ip
device_tasks[source_ip] = s.get(
f"https://{source_ip}/proxy/protect/api"
)
system_tasks[source_ip] = s.get(f"https://{source_ip}/api/system")
results: list[ClientResponse | Exception] = await asyncio.gather(
*device_tasks.values(), return_exceptions=True
*(*device_tasks.values(), *system_tasks.values()),
return_exceptions=True,
)
device_task_len = len(device_tasks)
for idx, source_ip in enumerate(device_tasks):
response = results[idx]
device_response = results[idx]
response_list[source_ip].services[UnifiService.Protect] = (
response.status == HTTPStatus.UNAUTHORIZED
if not isinstance(response, Exception)
device_response.status == HTTPStatus.UNAUTHORIZED
if not isinstance(device_response, Exception)
else False
)
system_response = results[idx + device_task_len]
if isinstance(system_response, Exception):
continue
try:
system = await system_response.json()
except (asyncio.TimeoutError, ClientError):
_LOGGER.exception("Failed to get system info for %s", source_ip)
continue
if not system:
continue
device = response_list[source_ip]
short_name = system.get("hardware", {}).get("shortname")
response_list[source_ip] = replace(
device,
platform=device.platform or short_name,
hostname=device.hostname or system.get("name").replace(" ", "-"),
hw_addr=device.hw_addr or _format_mac(system.get("mac")),
direct_connect_domain=system.get("directConnectDomain"),
is_sso_enabled=system.get("isSsoEnabled"),
is_single_user=system.get("isSingleUser"),
)

async def async_scan(
self, timeout: int = 31, address: str | None = None
Expand Down Expand Up @@ -428,7 +470,7 @@ def _on_response(data: bytes, addr: tuple[str, int]) -> None:
finally:
transport.close()

await self._probe_services(response_list)
await self._probe_services_and_system(response_list)
await self._add_missing_hw_addresses(response_list)

self.found_devices = list(response_list.values())
Expand Down

0 comments on commit 9e1d951

Please sign in to comment.