From 7054d62600efccfebe4031bb97fc1a094f584b16 Mon Sep 17 00:00:00 2001 From: Carl Lundin <108372512+clundin25@users.noreply.github.com> Date: Tue, 23 Apr 2024 13:32:17 -0700 Subject: [PATCH] fix: Clean up local server socket on error (#339) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This resolves https://togithub.com/googleapis/google-auth-library-python-oauthlib/issues/338. Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/google-auth-library-python-oauthlib/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes # 🦕 --- google_auth_oauthlib/flow.py | 52 +++++++++++++++++++----------------- tests/unit/test_flow.py | 15 +++++++++++ 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/google_auth_oauthlib/flow.py b/google_auth_oauthlib/flow.py index c5d8bce..e564ca4 100644 --- a/google_auth_oauthlib/flow.py +++ b/google_auth_oauthlib/flow.py @@ -433,31 +433,33 @@ def run_local_server( bind_addr or host, port, wsgi_app, handler_class=_WSGIRequestHandler ) - redirect_uri_format = ( - "http://{}:{}/" if redirect_uri_trailing_slash else "http://{}:{}" - ) - self.redirect_uri = redirect_uri_format.format(host, local_server.server_port) - auth_url, _ = self.authorization_url(**kwargs) - - if open_browser: - # if browser is None it defaults to default browser - webbrowser.get(browser).open(auth_url, new=1, autoraise=True) - - if authorization_prompt_message: - print(authorization_prompt_message.format(url=auth_url)) - - local_server.timeout = timeout_seconds - local_server.handle_request() - - # Note: using https here because oauthlib is very picky that - # OAuth 2.0 should only occur over https. - authorization_response = wsgi_app.last_request_uri.replace("http", "https") - self.fetch_token( - authorization_response=authorization_response, audience=token_audience - ) - - # This closes the socket - local_server.server_close() + try: + redirect_uri_format = ( + "http://{}:{}/" if redirect_uri_trailing_slash else "http://{}:{}" + ) + self.redirect_uri = redirect_uri_format.format( + host, local_server.server_port + ) + auth_url, _ = self.authorization_url(**kwargs) + + if open_browser: + # if browser is None it defaults to default browser + webbrowser.get(browser).open(auth_url, new=1, autoraise=True) + + if authorization_prompt_message: + print(authorization_prompt_message.format(url=auth_url)) + + local_server.timeout = timeout_seconds + local_server.handle_request() + + # Note: using https here because oauthlib is very picky that + # OAuth 2.0 should only occur over https. + authorization_response = wsgi_app.last_request_uri.replace("http", "https") + self.fetch_token( + authorization_response=authorization_response, audience=token_audience + ) + finally: + local_server.server_close() return self.credentials diff --git a/tests/unit/test_flow.py b/tests/unit/test_flow.py index 3e61fcd..a3314d1 100644 --- a/tests/unit/test_flow.py +++ b/tests/unit/test_flow.py @@ -25,6 +25,7 @@ import pytest import requests import urllib +import webbrowser from google_auth_oauthlib import flow @@ -440,3 +441,17 @@ def test_run_local_server_occupied_port( with pytest.raises(OSError) as exc: instance.run_local_server(port=port) assert "address already in use" in exc.strerror.lower() + + @mock.patch("google_auth_oauthlib.flow.webbrowser.get", autospec=True) + @mock.patch("wsgiref.simple_server.make_server", autospec=True) + def test_local_server_socket_cleanup( + self, make_server_mock, webbrowser_mock, instance + ): + server_mock = mock.MagicMock() + make_server_mock.return_value = server_mock + webbrowser_mock.side_effect = webbrowser.Error("Browser not found") + + with pytest.raises(webbrowser.Error): + instance.run_local_server() + + server_mock.server_close.assert_called_once()