diff --git a/tests/integration/test_proxy.py b/tests/integration/test_proxy.py index c49ea6b0..71045dd8 100644 --- a/tests/integration/test_proxy.py +++ b/tests/integration/test_proxy.py @@ -1,5 +1,6 @@ """Test using a proxy.""" +import asyncio import http.server import socketserver import threading @@ -36,6 +37,35 @@ def do_GET(self): self.end_headers() self.copyfile(upstream_response, self.wfile) + def do_CONNECT(self): + host, port = self.path.split(":") + + asyncio.run(self._tunnel(host, port, self.connection)) + + async def _tunnel(self, host, port, client_sock): + target_r, target_w = await asyncio.open_connection(host=host, port=port) + + self.send_response(http.HTTPStatus.OK) + self.end_headers() + + source_r, source_w = await asyncio.open_connection(sock=client_sock) + + async def channel(reader, writer): + while True: + data = await reader.read(1024) + if not data: + break + writer.write(data) + await writer.drain() + + writer.close() + await writer.wait_closed() + + await asyncio.gather( + channel(target_r, source_w), + channel(source_r, target_w), + ) + @pytest.fixture(scope="session") def proxy_server(): @@ -59,3 +89,23 @@ def test_use_proxy(tmpdir, httpbin, proxy_server): assert cassette_response.headers[key] == response.headers[key] assert cassette_response.headers == response.headers assert cassette.play_count == 1 + + +def test_use_https_proxy(tmpdir, httpbin_secure, proxy_server): + """Ensure that it works with an HTTPS proxy.""" + with vcr.use_cassette(str(tmpdir.join("proxy.yaml"))): + response = requests.get(httpbin_secure.url, proxies={"https": proxy_server}) + + with vcr.use_cassette(str(tmpdir.join("proxy.yaml")), mode="once") as cassette: + cassette_response = requests.get( + httpbin_secure.url, + proxies={"https": proxy_server}, + ) + + for key in set(cassette_response.headers.keys()) & set(response.headers.keys()): + assert cassette_response.headers[key] == response.headers[key] + assert cassette_response.headers == response.headers + assert cassette.play_count == 1 + + # The cassette url points to httpbin, not the proxy + assert cassette.requests[0].url == httpbin_secure.url + "/" diff --git a/vcr/stubs/__init__.py b/vcr/stubs/__init__.py index 4d4bb39d..336038ff 100644 --- a/vcr/stubs/__init__.py +++ b/vcr/stubs/__init__.py @@ -186,22 +186,34 @@ def _port_postfix(self): """ Returns empty string for the default port and ':port' otherwise """ - port = self.real_connection.port + port = ( + self.real_connection.port + if not self.real_connection._tunnel_host + else self.real_connection._tunnel_port + ) default_port = {"https": 443, "http": 80}[self._protocol] return f":{port}" if port != default_port else "" + def _real_host(self): + """Returns the request host""" + if self.real_connection._tunnel_host: + # The real connection is to an HTTPS proxy + return self.real_connection._tunnel_host + else: + return self.real_connection.host + def _uri(self, url): """Returns request absolute URI""" if url and not url.startswith("/"): # Then this must be a proxy request. return url - uri = f"{self._protocol}://{self.real_connection.host}{self._port_postfix()}{url}" + uri = f"{self._protocol}://{self._real_host()}{self._port_postfix()}{url}" log.debug("Absolute URI: %s", uri) return uri def _url(self, uri): """Returns request selector url from absolute URI""" - prefix = f"{self._protocol}://{self.real_connection.host}{self._port_postfix()}" + prefix = f"{self._protocol}://{self._real_host()}{self._port_postfix()}" return uri.replace(prefix, "", 1) def request(self, method, url, body=None, headers=None, *args, **kwargs):