diff --git a/tests/integration/test_proxy.py b/tests/integration/test_proxy.py index 7366d33e..5ceae38e 100644 --- a/tests/integration/test_proxy.py +++ b/tests/integration/test_proxy.py @@ -1,5 +1,6 @@ """Test using a proxy.""" +import asyncio import http.server import socketserver import threading @@ -36,10 +37,39 @@ def do_GET(self): self.end_headers() self.copyfile(upstream_response, self.wfile) + def do_CONNECT(self): + host, port = self.path.split(":") + + asyncio.run(self._tunnel(host, port, self.connection)) + + async def _tunnel(self, host, port, client_sock): + target_r, target_w = await asyncio.open_connection(host=host, port=port) + + self.send_response(http.HTTPStatus.OK) + self.end_headers() + + source_r, source_w = await asyncio.open_connection(sock=client_sock) + + async def channel(reader, writer): + while True: + data = await reader.read(1024) + if not data: + break + writer.write(data) + await writer.drain() + + writer.close() + await writer.wait_closed() + + await asyncio.gather( + channel(target_r, source_w), + channel(source_r, target_w), + ) + @pytest.fixture(scope="session") def proxy_server(): - httpd = socketserver.ThreadingTCPServer(("", 0), Proxy) + httpd = socketserver.ThreadingTCPServer(("127.0.0.3", 0), Proxy) proxy_process = threading.Thread(target=httpd.serve_forever) proxy_process.start() yield "http://{}:{}".format(*httpd.server_address) @@ -59,3 +89,22 @@ def test_use_proxy(tmpdir, httpbin, proxy_server): assert cassette_response.headers[key] == response.headers[key] assert cassette_response.headers == response.headers assert cassette.play_count == 1 + + +def test_use_https_proxy(tmpdir, httpbin_secure, proxy_server): + """Ensure that it works with an HTTPS proxy.""" + with vcr.use_cassette(str(tmpdir.join("proxy.yaml"))): + response = requests.get(httpbin_secure.url, proxies={"https": proxy_server}) + + with vcr.use_cassette(str(tmpdir.join("proxy.yaml")), mode="once") as cassette: + cassette_response = requests.get( + httpbin_secure.url, proxies={"https": proxy_server} + ) + + for key in set(cassette_response.headers.keys()) & set(response.headers.keys()): + assert cassette_response.headers[key] == response.headers[key] + assert cassette_response.headers == response.headers + assert cassette.play_count == 1 + + # The cassette url points to httpbin, not the proxy + assert cassette.requests[0].url == httpbin_secure.url + "/"