Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close all and counts #13

Merged
merged 5 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 115 additions & 83 deletions adafruit_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


if not sys.implementation.name == "circuitpython":
from typing import Optional, Tuple
from typing import List, Optional, Tuple

from circuitpython_typing.socket import (
CircuitPythonSocketType,
Expand Down Expand Up @@ -71,8 +71,7 @@ class _FakeSSLContext:
def __init__(self, iface: InterfaceType) -> None:
self._iface = iface

# pylint: disable=unused-argument
def wrap_socket(
def wrap_socket( # pylint: disable=unused-argument
self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None
) -> _FakeSSLSocket:
"""Return the same socket"""
Expand All @@ -99,7 +98,8 @@ def create_fake_ssl_context(
return _FakeSSLContext(iface)


_global_socketpool = {}
_global_connection_managers = {}
dhalbert marked this conversation as resolved.
Show resolved Hide resolved
_global_socketpools = {}
_global_ssl_contexts = {}


Expand All @@ -113,7 +113,7 @@ def get_radio_socketpool(radio):
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
"""
class_name = radio.__class__.__name__
if class_name not in _global_socketpool:
if class_name not in _global_socketpools:
if class_name == "Radio":
import ssl # pylint: disable=import-outside-toplevel

Expand Down Expand Up @@ -151,10 +151,10 @@ def get_radio_socketpool(radio):
else:
raise AttributeError(f"Unsupported radio class: {class_name}")

_global_socketpool[class_name] = pool
_global_socketpools[class_name] = pool
_global_ssl_contexts[class_name] = ssl_context

return _global_socketpool[class_name]
return _global_socketpools[class_name]


def get_radio_ssl_context(radio):
Expand Down Expand Up @@ -183,42 +183,75 @@ def __init__(
) -> None:
self._socket_pool = socket_pool
# Hang onto open sockets so that we can reuse them.
self._available_socket = {}
self._open_sockets = {}

def _free_sockets(self) -> None:
available_sockets = []
for socket, free in self._available_socket.items():
if free:
available_sockets.append(socket)
self._available_sockets = set()
self._key_by_managed_socket = {}
self._managed_socket_by_key = {}

def _free_sockets(self, force: bool = False) -> None:
# cloning lists since items are being removed
available_sockets = list(self._available_sockets)
for socket in available_sockets:
self.close_socket(socket)
if force:
open_sockets = list(self._managed_socket_by_key.values())
for socket in open_sockets:
self.close_socket(socket)

def _get_key_for_socket(self, socket):
def _get_connected_socket( # pylint: disable=too-many-arguments
self,
addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]],
host: str,
port: int,
timeout: float,
is_ssl: bool,
ssl_context: Optional[SSLContextType] = None,
):
try:
return next(
key for key, value in self._open_sockets.items() if value == socket
)
except StopIteration:
return None
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
except (OSError, RuntimeError) as exc:
return exc

if is_ssl:
socket = ssl_context.wrap_socket(socket, server_hostname=host)
connect_host = host
else:
connect_host = addr_info[-1][0]
socket.settimeout(timeout) # socket read timeout

try:
socket.connect((connect_host, port))
except (MemoryError, OSError) as exc:
socket.close()
return exc

return socket

@property
def available_socket_count(self) -> int:
"""Get the count of freeable open sockets"""
return len(self._available_sockets)

@property
def managed_socket_count(self) -> int:
"""Get the count of open sockets"""
return len(self._managed_socket_by_key)

def close_socket(self, socket: SocketType) -> None:
"""Close a previously opened socket."""
if socket not in self._open_sockets.values():
if socket not in self._managed_socket_by_key.values():
raise RuntimeError("Socket not managed")
key = self._get_key_for_socket(socket)
socket.close()
del self._available_socket[socket]
del self._open_sockets[key]
key = self._key_by_managed_socket.pop(socket)
del self._managed_socket_by_key[key]
if socket in self._available_sockets:
self._available_sockets.remove(socket)

def free_socket(self, socket: SocketType) -> None:
"""Mark a previously opened socket as available so it can be reused if needed."""
if socket not in self._open_sockets.values():
if socket not in self._managed_socket_by_key.values():
raise RuntimeError("Socket not managed")
self._available_socket[socket] = True
self._available_sockets.add(socket)

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def get_socket(
self,
host: str,
Expand All @@ -234,10 +267,10 @@ def get_socket(
if session_id:
session_id = str(session_id)
key = (host, port, proto, session_id)
if key in self._open_sockets:
socket = self._open_sockets[key]
if self._available_socket[socket]:
self._available_socket[socket] = False
if key in self._managed_socket_by_key:
socket = self._managed_socket_by_key[key]
if socket in self._available_sockets:
self._available_sockets.remove(socket)
return socket

raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}")
Expand All @@ -253,64 +286,63 @@ def get_socket(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]

try_count = 0
socket = None
last_exc = None
while try_count < 2 and socket is None:
try_count += 1
if try_count > 1:
if any(
socket
for socket, free in self._available_socket.items()
if free is True
):
self._free_sockets()
else:
break
result = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
if isinstance(result, Exception):
# Got an error, if there are any available sockets, free them and try again
if self.available_socket_count:
self._free_sockets()
result = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
if isinstance(result, Exception):
raise RuntimeError(f"Error connecting socket: {result}") from result
Copy link

@RetiredWizard RetiredWizard Apr 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block is a lot cleaner. I still worry about tossing the first error though, in my experience the first error is often the most important. What do you think about saving the result of the first _get_connect_socket call if it's an exception and then if the second call also returns an exception, printing the result before actually raising the error. Something like:

        if isinstance(result, Exception):
            if result_sav is not None:
                print(f"Error connecting  socket: {sys.exception(result_sav)}\nTrying again")
            raise RuntimeError(f"Error connecting socket: {result}") from result

This fakes the timing of events a bit since both errors would be displayed after both attempts fail, but has the advantage of keeping the output clean if the second attempt succeeds.

Edit: Actually if the results is an exception at the end then the first attempt must have failed so you wouldn't need to do the result_sav is not None test.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RetiredWizard this is tough. I wouldn't want a core library printing. I think with a little thought, we could flag a debug mode that prints them all (maybe with a connection_manager_set_debug(socket). I also have a fear on too much bloat...

What about:

        last_result = ""
        result = self._get_connected_socket(
            addr_info, host, port, timeout, is_ssl, ssl_context
        )
        if isinstance(result, Exception):
            # Got an error, if there are any available sockets, free them and try again
            if self.available_socket_count:
                last_result = f", first error: {result}"
                self._free_sockets()
                result = self._get_connected_socket(
                    addr_info, host, port, timeout, is_ssl, ssl_context
                )
        if isinstance(result, Exception):
            raise RuntimeError(f"Error connecting socket: {result}{last_result}") from result

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would update this test:

def test_get_socket_runtime_error_ties_again_only_once():
    mock_pool = mocket.MocketPool()
    mock_socket_1 = mocket.Mocket()
    mock_socket_2 = mocket.Mocket()
    mock_pool.socket.side_effect = [
        mock_socket_1,
        RuntimeError("error 1"),
        RuntimeError("error 2"),
        RuntimeError("error 3"),
        mock_socket_2,
    ]

    free_sockets_mock = mock.Mock()
    connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)
    connection_manager._free_sockets = free_sockets_mock

    # get a socket and then mark as free
    socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
    assert socket == mock_socket_1
    connection_manager.free_socket(socket)
    free_sockets_mock.assert_not_called()

    # try to get a socket that returns a RuntimeError twice
    with pytest.raises(RuntimeError) as context:
        connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:")
    assert "Error connecting socket: error 2, first error: error 1" in str(context)
    free_sockets_mock.assert_called_once()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea If that does what it looks like it does, that's great. I also see the argument to drop for the sake of bloat though. Some of the WiFi boards have really thin available memory and a few bytes can make a big difference. I'll let you make the call 😁

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dhalbert are you good with this change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems ok. You can save just save the exception (first_exception) and then only make a formatted string in the second if.

Yes, reducing code size is good wherever you can do it for frozen libraries.


try:
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
except OSError as exc:
last_exc = exc
continue
except RuntimeError as exc:
last_exc = exc
continue

if is_ssl:
socket = ssl_context.wrap_socket(socket, server_hostname=host)
connect_host = host
else:
connect_host = addr_info[-1][0]
socket.settimeout(timeout) # socket read timeout

try:
socket.connect((connect_host, port))
except MemoryError as exc:
last_exc = exc
socket.close()
socket = None
except OSError as exc:
last_exc = exc
socket.close()
socket = None

if socket is None:
raise RuntimeError(f"Error connecting socket: {last_exc}") from last_exc

self._available_socket[socket] = False
self._open_sockets[key] = socket
return socket
self._key_by_managed_socket[result] = key
self._managed_socket_by_key[key] = result
return result


# global helpers


_global_connection_manager = {}
def connection_manager_close_all(
socket_pool: Optional[SocketpoolModuleType] = None, release_references: bool = False
) -> None:
"""Close all open sockets for pool"""
if socket_pool:
socket_pools = [socket_pool]
else:
socket_pools = _global_connection_managers.keys()

for pool in socket_pools:
connection_manager = _global_connection_managers.get(pool, None)
if connection_manager is None:
raise RuntimeError("SocketPool not managed")

connection_manager._free_sockets(force=True) # pylint: disable=protected-access

if release_references:
radio_key = None
for radio_check, pool_check in _global_socketpools.items():
if pool == pool_check:
radio_key = radio_check
break

if radio_key:
if radio_key in _global_socketpools:
del _global_socketpools[radio_key]

if radio_key in _global_ssl_contexts:
del _global_ssl_contexts[radio_key]

if pool in _global_connection_managers:
del _global_connection_managers[pool]


def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager:
"""Get the ConnectionManager singleton for the given pool"""
if socket_pool not in _global_connection_manager:
_global_connection_manager[socket_pool] = ConnectionManager(socket_pool)
return _global_connection_manager[socket_pool]
if socket_pool not in _global_connection_managers:
_global_connection_managers[socket_pool] = ConnectionManager(socket_pool)
return _global_connection_managers[socket_pool]
32 changes: 28 additions & 4 deletions examples/connectionmanager_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,38 @@

# get request session
requests = adafruit_requests.Session(pool, ssl_context)
connection_manager = adafruit_connection_manager.get_connection_manager(pool)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output of this is so:

----------------------------------------
Nothing yet opened
Open Sockets: 0
Freeable Open Sockets: 0
----------------------------------------
Fetching from http://wifitest.adafruit.com/testwifi/index.html in a context handler
Text Response This is a test of Adafruit WiFi!
If you can read this, its working :)
----------------------------------------
1 request, opened and freed
Open Sockets: 1
Freeable Open Sockets: 1
----------------------------------------
Fetching from http://wifitest.adafruit.com/testwifi/index.html not in a context handler
----------------------------------------
1 request, opened but not freed
Open Sockets: 1
Freeable Open Sockets: 0
----------------------------------------
Closing everything in the pool
----------------------------------------
Everything closed
Open Sockets: 0
Freeable Open Sockets: 0

print("-" * 40)
print("Nothing yet opened")
print(f"Open Sockets: {connection_manager.managed_socket_count}")
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")

# make request
print("-" * 40)
print(f"Fetching from {TEXT_URL}")
print(f"Fetching from {TEXT_URL} in a context handler")
with requests.get(TEXT_URL) as response:
response_text = response.text
print(f"Text Response {response_text}")

print("-" * 40)
print("1 request, opened and freed")
print(f"Open Sockets: {connection_manager.managed_socket_count}")
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")

print("-" * 40)
print(f"Fetching from {TEXT_URL} not in a context handler")
response = requests.get(TEXT_URL)
response_text = response.text
response.close()

print(f"Text Response {response_text}")
print("-" * 40)
print("1 request, opened but not freed")
print(f"Open Sockets: {connection_manager.managed_socket_count}")
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")

print("-" * 40)
print("Closing everything in the pool")
adafruit_connection_manager.connection_manager_close_all(pool)

print("-" * 40)
print("Everything closed")
print(f"Open Sockets: {connection_manager.managed_socket_count}")
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")
8 changes: 4 additions & 4 deletions tests/close_socket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ def test_close_socket():
socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
key = (mocket.MOCK_HOST_1, 80, "http:", None)
assert socket == mock_socket_1
assert socket in connection_manager._available_socket
assert key in connection_manager._open_sockets
assert socket not in connection_manager._available_sockets
assert key in connection_manager._managed_socket_by_key

# validate socket is no longer tracked
connection_manager.close_socket(socket)
assert socket not in connection_manager._available_socket
assert key not in connection_manager._open_sockets
assert socket not in connection_manager._available_sockets
assert key not in connection_manager._managed_socket_by_key


def test_close_socket_not_managed():
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def adafruit_wiznet5k_with_ssl_socket_module():
@pytest.fixture(autouse=True)
def reset_connection_manager(monkeypatch):
monkeypatch.setattr(
"adafruit_connection_manager._global_socketpool",
"adafruit_connection_manager._global_connection_managers",
{},
)
monkeypatch.setattr(
"adafruit_connection_manager._global_socketpools",
{},
)
monkeypatch.setattr(
Expand Down
Loading
Loading