Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

close(): fix exception and enable fast-close #156

Merged
merged 4 commits into from
Mar 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions adafruit_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,14 @@ class Response:

encoding = None

def __init__(self, sock: SocketType, session: "Session") -> None:
def __init__(
self, sock: SocketType, session: "Session", fast_close: bool = False
) -> None:
self.socket = sock
self.encoding = "utf-8"
self._cached = None
self._headers = {}
self._fast_close = fast_close

# _start_index and _receive_buffer are used when parsing headers.
# _receive_buffer will grow by 32 bytes everytime it is too small.
Expand Down Expand Up @@ -231,17 +234,18 @@ def close(self) -> None:
if not self.socket:
return
# Make sure we've read all of our response.
if self._cached is None:
if self._cached is None and not self._fast_close:
if self._remaining and self._remaining > 0:
self._throw_away(self._remaining)
elif self._chunked:
while True:
chunk_header = bytes(self._readto(b"\r\n")).split(b";", 1)[0]
if not chunk_header:
break
chunk_size = int(bytes(chunk_header), 16)
if chunk_size == 0:
break
self._throw_away(chunk_size + 2)
self._parse_headers()
if self._session:
# pylint: disable=protected-access
self._session._connection_manager.free_socket(self.socket)
Expand Down Expand Up @@ -361,11 +365,13 @@ def __init__(
socket_pool: SocketpoolModuleType,
ssl_context: Optional[SSLContextType] = None,
session_id: Optional[str] = None,
fast_close: Optional[bool] = False,
) -> None:
self._connection_manager = get_connection_manager(socket_pool)
self._ssl_context = ssl_context
self._session_id = session_id
self._last_response = None
self._fast_close = fast_close

@staticmethod
def _check_headers(headers: Dict[str, str]):
Expand Down Expand Up @@ -560,7 +566,7 @@ def request(
if not socket:
raise OutOfRetries("Repeated socket failures") from last_exc

resp = Response(socket, self) # our response
resp = Response(socket, self, fast_close=self._fast_close) # our response
if allow_redirects:
if "location" in resp.headers and 300 <= resp.status_code <= 399:
# a naive handler for redirects
Expand Down
Loading