Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add utils #23

Merged
merged 4 commits into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ socket = await start_connection(addr_infos, local_addr_infos=local_addr_infos, h
transport, protocol = await loop.create_connection(
MyProtocol, sock=socket, ...)

# Remove the first address for each family from addr_info
pop_addr_infos_interleave(addr_info, 1)

# Remove all matching address from addr_info
remove_addr_infos(addr_info, "dead::beef::")

```

## Credits
Expand Down
6 changes: 6 additions & 0 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,10 @@ socket = await aiohappyeyeballs.start_connection(addr_infos, local_addr_infos=lo

transport, protocol = await loop.create_connection(
MyProtocol, sock=socket, ...)

# Remove the first address for each family from addr_info
aiohappyeyeballs.pop_addr_infos_interleave(addr_info, 1)

# Remove all matching address from addr_info
aiohappyeyeballs.remove_addr_infos(addr_info, "dead::beef::")
```
11 changes: 9 additions & 2 deletions src/aiohappyeyeballs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
__version__ = "1.7.0"

from .impl import AddrInfoType, start_connection
from .impl import start_connection
from .types import AddrInfoType
from .utils import pop_addr_infos_interleave, remove_addr_infos

__all__ = ("start_connection", "AddrInfoType")
__all__ = (
"start_connection",
"AddrInfoType",
"remove_addr_infos",
"pop_addr_infos_interleave",
)
11 changes: 11 additions & 0 deletions src/aiohappyeyeballs/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Base implementation."""
import socket
from typing import Tuple, Union

AddrInfoType = Tuple[
Union[int, socket.AddressFamily],
Union[int, socket.SocketKind],
int,
str,
Tuple, # type: ignore[type-arg]
]
51 changes: 51 additions & 0 deletions src/aiohappyeyeballs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Utility functions for aiohappyeyeballs."""

import ipaddress
from typing import Dict, List

from .types import AddrInfoType


def pop_addr_infos_interleave(addr_infos: List[AddrInfoType], interleave: int) -> None:
"""
Pop addr_info from the list of addr_infos by family up to interleave times.

The interleave parameter is used to know how many addr_infos for
each family should be popped of the top of the list.
"""
seen: Dict[int, int] = {}
to_remove: List[AddrInfoType] = []
for addr_info in addr_infos:
family = addr_info[0]
if family not in seen:
seen[family] = 0
if seen[family] < interleave:
to_remove.append(addr_info)
seen[family] += 1
for addr_info in to_remove:
addr_infos.remove(addr_info)


def remove_addr_infos(
addr_infos: List[AddrInfoType],
address: str,
) -> None:
"""Remove an address from the list of addr_infos."""
bad_addrs_infos: List[AddrInfoType] = []
for addr_info in addr_infos:
if addr_info[-1][0] == address:
bad_addrs_infos.append(addr_info)
if bad_addrs_infos:
for bad_addr_info in bad_addrs_infos:
addr_infos.remove(bad_addr_info)
return
# Slow path in case addr is formatted differently
ip_address = ipaddress.ip_address(address)
for addr_info in addr_infos:
if ip_address == ipaddress.ip_address(addr_info[-1][0]):
bad_addrs_infos.append(addr_info)
if bad_addrs_infos:
for bad_addr_info in bad_addrs_infos:
addr_infos.remove(bad_addr_info)
return
raise ValueError(f"Address {address} not found in addr_infos")
107 changes: 107 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import socket
from typing import List

import pytest

from aiohappyeyeballs import AddrInfoType, pop_addr_infos_interleave, remove_addr_infos


def test_pop_addr_infos_interleave():
"""Test pop_addr_infos_interleave."""
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
addr_info_copy = addr_info.copy()
pop_addr_infos_interleave(addr_info_copy, 1)
assert addr_info_copy == [ipv6_addr_info_2]
pop_addr_infos_interleave(addr_info_copy, 1)
assert addr_info_copy == []
addr_info_copy = addr_info.copy()
pop_addr_infos_interleave(addr_info_copy, 2)
assert addr_info_copy == []


def test_remove_addr_infos():
"""Test remove_addr_infos."""
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
addr_info_copy = addr_info.copy()
remove_addr_infos(addr_info_copy, "dead:beef::")
assert addr_info_copy == [ipv6_addr_info_2, ipv4_addr_info]
remove_addr_infos(addr_info_copy, "dead:aaaa::")
assert addr_info_copy == [ipv4_addr_info]
remove_addr_infos(addr_info_copy, "107.6.106.83")
assert addr_info_copy == []


def test_remove_addr_infos_slow_path():
"""Test remove_addr_infos with mis-matched formatting."""
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
addr_info_copy = addr_info.copy()
remove_addr_infos(addr_info_copy, "dead:beef:0000:0000:0000:0000:0000:0000")
assert addr_info_copy == [ipv6_addr_info_2, ipv4_addr_info]
remove_addr_infos(addr_info_copy, "dead:aaaa:0000:0000:0000:0000:0000:0000")
assert addr_info_copy == [ipv4_addr_info]
with pytest.raises(ValueError, match="Address 107.6.106.2 not found in addr_infos"):
remove_addr_infos(addr_info_copy, "107.6.106.2")
assert addr_info_copy == [ipv4_addr_info]
Loading