Skip to content

Commit

Permalink
Rough in client ipv6 support
Browse files Browse the repository at this point in the history
For #12.
  • Loading branch information
mnot committed May 27, 2022
1 parent ea4a0bd commit f202947
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 50 deletions.
30 changes: 15 additions & 15 deletions test/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@


class TestDns(unittest.TestCase):

def setUp(self):
self.loop = loop.make()
self.loop.schedule(5, self.timeout)
Expand All @@ -28,25 +27,26 @@ def check_gai_error(self, results):
self.assertTrue(isinstance(results, socket.gaierror), results)

def test_basic(self):
lookup(b'www.google.com', self.check_success)
lookup(b"www.google.com", 80, socket.SOCK_STREAM, self.check_success)
self.loop.run()

def test_lots(self):
lookup(b'www.google.com', self.check_success)
lookup(b'www.facebook.com', self.check_success)
lookup(b'www.example.com', self.check_success)
lookup(b'www.ietf.org', self.check_success)
lookup(b'www.github.com', self.check_success)
lookup(b'www.twitter.com', self.check_success)
lookup(b'www.abc.net.au', self.check_success)
lookup(b'www.mnot.net', self.check_success)
lookup(b'www.eff.org', self.check_success)
lookup(b'www.aclu.org', self.check_success)
lookup(b"www.google.com", 443, socket.SOCK_STREAM, self.check_success)
lookup(b"www.facebook.com", 80, socket.SOCK_STREAM, self.check_success)
lookup(b"www.example.com", 80, socket.SOCK_STREAM, self.check_success)
lookup(b"www.ietf.org", 443, socket.SOCK_STREAM, self.check_success)
lookup(b"www.github.com", 443, socket.SOCK_STREAM, self.check_success)
lookup(b"www.twitter.com", 443, socket.SOCK_STREAM, self.check_success)
lookup(b"www.abc.net.au", 80, socket.SOCK_STREAM, self.check_success)
lookup(b"www.mnot.net", 443, socket.SOCK_STREAM, self.check_success)
lookup(b"www.eff.org", 443, socket.SOCK_STREAM, self.check_success)
lookup(b"www.aclu.org", 443, socket.SOCK_STREAM, self.check_success)
self.loop.run()

def test_gai(self):
lookup(b'foo.foo', self.check_gai_error)
lookup(b'bar.bar', self.check_gai_error)
lookup(b"foo.foo", 23, socket.SOCK_STREAM, self.check_gai_error)
lookup(b"bar.bar", 23, socket.SOCK_DGRAM, self.check_gai_error)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
25 changes: 20 additions & 5 deletions thor/dns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,41 @@
#!/usr/bin/env python

from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
import socket
from typing import Callable, Union
from typing import Callable, Union, Tuple, List

pool_size = 10

DnsResult = Tuple[socket.AddressFamily, socket.SocketKind, int, str, Tuple[str, int]]
DnsResultList = List[DnsResult]

def lookup(host: bytes, cb: Callable[..., None]) -> None:
f = _pool.submit(_lookup, host)

def lookup(host: bytes, port: int, proto: int, cb: Callable[..., None]) -> None:
f = _pool.submit(_lookup, host, port, proto)

def done(ff: Future) -> None:
cb(ff.result())

f.add_done_callback(done)


def _lookup(host: bytes) -> Union[str, Exception]:
def _lookup(host: bytes, port: int, socktype: int) -> Union[DnsResultList, Exception]:
try:
return socket.gethostbyname(host.decode("idna"))
return socket.getaddrinfo(host, port, type=socktype) # type: ignore
except Exception as why:
return why


def pickDnsResult(results: DnsResultList) -> DnsResult:
table = defaultdict(list)
for result in results:
table[result[0]].append(result)

if socket.has_ipv6 and socket.AF_INET6 in table:
return table[socket.AF_INET6][0]
else:
return table[socket.AF_INET][0]


_pool = ThreadPoolExecutor(max_workers=pool_size)
60 changes: 43 additions & 17 deletions thor/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,24 @@ def _parse_uri(self, uri: bytes) -> OriginType:
raise ValueError
if b"@" in authority:
authority = authority.split(b"@", 1)[1]
if b":" in authority:
portb = None
ipv6_literal = False
if authority.startswith(b"["):
ipv6_literal = True
try:
delimiter = authority.index(b"]")
except ValueError:
self.input_error(UrlError("IPv6 URL missing ]"), False)
raise ValueError
hostb = authority[1:delimiter]
rest = authority[delimiter + 1 :]
if rest.startswith(b":"):
portb = rest[1:]
elif b":" in authority:
hostb, portb = authority.rsplit(b":", 1)
else:
hostb = authority
if portb:
try:
port = int(portb.decode("utf-8", "replace"))
except ValueError:
Expand All @@ -293,32 +309,42 @@ def _parse_uri(self, uri: bytes) -> OriginType:
self.input_error(UrlError("URL port %i out of range" % port), False)
raise ValueError
else:
hostb, port = authority, default_port
port = default_port
try:
host = hostb.decode("ascii", "strict")
except UnicodeDecodeError:
self.input_error(UrlError("URL host has non-ascii characters"), False)
raise ValueError
if not all(c in ascii_letters + digits + ".-" for c in host):
self.input_error(UrlError("URL hostname has disallowed character"), False)
raise ValueError
if ipv6_literal:
print(host)
if not all(c in digits + ":abcdefABCDEF" for c in host):
self.input_error(
UrlError("URL IPv6 literal has disallowed character"), False
)
raise ValueError
else:
if not all(c in ascii_letters + digits + ".-" for c in host):
self.input_error(
UrlError("URL hostname has disallowed character"), False
)
raise ValueError
labels = host.split(".")
if any(len(l) == 0 for l in labels):
self.input_error(UrlError("URL hostname has empty label"), False)
raise ValueError
if any(len(l) > 63 for l in labels):
self.input_error(
UrlError("URL hostname label greater than 63 characters"), False
)
raise ValueError
# if any(l[0].isdigit() for l in labels):
# self.input_error(UrlError("URL hostname label starts with digit"), False)
# raise ValueError
if len(host) > 255:
self.input_error(
UrlError("URL hostname greater than 255 characters"), False
)
raise ValueError
labels = host.split(".")
if any(len(l) == 0 for l in labels):
self.input_error(UrlError("URL hostname has empty label"), False)
raise ValueError
if any(len(l) > 63 for l in labels):
self.input_error(
UrlError("URL hostname label greater than 63 characters"), False
)
raise ValueError
# if any(l[0].isdigit() for l in labels):
# self.input_error(UrlError("URL hostname label starts with digit"), False)
# raise ValueError
if path == b"":
path = b"/"
self.authority = authority
Expand Down
17 changes: 9 additions & 8 deletions thor/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Tuple, List, Union, Type, Callable # pylint: disable=unused-import
import ssl as sys_ssl # pylint: disable=unused-import

from thor.dns import lookup
from thor.dns import lookup, pickDnsResult, DnsResultList
from thor.loop import EventSource, LoopBase, schedule
from thor.loop import ScheduledEvent # pylint: disable=unused-import

Expand Down Expand Up @@ -318,28 +318,29 @@ def connect(self, host: bytes, port: int, connect_timeout: float = None) -> None
self.handle_socket_error,
socket.error(errno.ETIMEDOUT, os.strerror(errno.ETIMEDOUT)),
)
lookup(host, self._continue_connect)
lookup(host, port, socket.SOCK_STREAM, self._continue_connect)

def _continue_connect(self, dns_result: Union[str, Exception]) -> None:
def _continue_connect(self, dns_results: Union[DnsResultList, Exception]) -> None:
"""
Continue connecting after DNS returns a result.
"""
if isinstance(dns_result, Exception):
self.handle_socket_error(dns_result, "gai")
if isinstance(dns_results, Exception):
self.handle_socket_error(dns_results, "gai")
return
dns_result = pickDnsResult(dns_results)
if self.check_ip is not None:
if not self.check_ip(dns_result):
if not self.check_ip(dns_result[4][0]):
self.handle_conn_error("access", 0, "IP Check failed")
return

self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock = socket.socket(dns_result[0], socket.SOCK_STREAM)
self.sock.setblocking(False)
self.once("fd_error", self.handle_fd_error)
self.register_fd(self.sock.fileno(), "fd_writable")
self.event_add("fd_error")
self.once("fd_writable", self.handle_connect)
try:
err = self.sock.connect_ex((dns_result, self.port))
err = self.sock.connect_ex(dns_result[4])
except socket.error as why:
self.handle_socket_error(why)
return
Expand Down
26 changes: 21 additions & 5 deletions thor/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import socket
from typing import Union

from thor.dns import lookup
from thor.dns import lookup, pickDnsResult, DnsResultList
from thor.loop import EventSource, LoopBase


Expand All @@ -35,10 +35,11 @@ def __init__(self, loop: LoopBase = None) -> None:
EventSource.__init__(self, loop)
self.host = None # type: bytes
self.port = None # type: int
self._error_sent = False
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.setblocking(False)
self.max_dgram = min(
(2 ** 16 - 40), self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
(2**16 - 40), self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
)
self.on("fd_readable", self.handle_datagram)
self.register_fd(self.sock.fileno())
Expand All @@ -58,10 +59,14 @@ def bind(self, host: bytes, port: int) -> None:
self.host = host
self.port = port
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
lookup(host, self._continue_bind)
lookup(host, port, socket.SOCK_DGRAM, self._continue_bind)

def _continue_bind(self, dns_result: Union[str, Exception]) -> None:
self.sock.bind((dns_result, self.port))
def _continue_bind(self, dns_results: Union[DnsResultList, Exception]) -> None:
if isinstance(dns_results, Exception):
self.handle_socket_error(dns_results, "gai")
return
dns_result = pickDnsResult(dns_results)
self.sock.bind(dns_result[4])

def shutdown(self) -> None:
"Close the listening socket."
Expand Down Expand Up @@ -99,3 +104,14 @@ def handle_datagram(self) -> None:
else:
raise
self.emit("datagram", data, addr[0], addr[1])

def handle_socket_error(self, why: Exception, err_type: str = "socket") -> None:
err_id = why.args[0]
err_str = why.args[1]
if self._error_sent:
return
self._error_sent = True
self.unregister_fd()
self.emit("socket_error", err_type, err_id, err_str)
if self.sock:
self.sock.close()

0 comments on commit f202947

Please sign in to comment.