From 5ea399dc341a2de8c06a9ba07e967008b48842cc Mon Sep 17 00:00:00 2001 From: Allison Karlitskaya Date: Thu, 15 Aug 2024 22:42:08 +0200 Subject: [PATCH] test: add a mock webserver in Python This is roughly a replacement for test-server. It's an outline for what a simple cockpit-ws in Python might look like (albeit without any of the necessary authentication code). It's already useful enough to run the majority of the existing QUnit, which means we can now run most unit tests without needing to build any C code. --- src/cockpit/jsonutil.py | 6 + test/conftest.py | 43 +++ test/pytest/mockdbusservice.py | 171 ++++++++++++ test/pytest/mockwebserver.py | 484 +++++++++++++++++++++++++++++++++ test/pytest/test_browser.py | 19 +- 5 files changed, 720 insertions(+), 3 deletions(-) create mode 100644 test/conftest.py create mode 100644 test/pytest/mockdbusservice.py create mode 100644 test/pytest/mockwebserver.py diff --git a/src/cockpit/jsonutil.py b/src/cockpit/jsonutil.py index 7df905c4e6c2..f4e2f1f21b54 100644 --- a/src/cockpit/jsonutil.py +++ b/src/cockpit/jsonutil.py @@ -83,6 +83,12 @@ def get_str(obj: JsonObject, key: str, default: Union[DT, _Empty] = _empty) -> U return _get(obj, lambda v: typechecked(v, str), key, default) +def get_str_map(obj: JsonObject, key: str, default: DT | _Empty = _empty) -> DT | Mapping[str, str]: + def as_str_map(value: JsonValue) -> Mapping[str, str]: + return {key: typechecked(value, str) for key, value in typechecked(value, dict).items()} + return _get(obj, as_str_map, key, default) + + def get_str_or_none(obj: JsonObject, key: str, default: Optional[str]) -> Optional[str]: return _get(obj, lambda v: None if v is None else typechecked(v, str), key, default) diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 000000000000..eb4b50969ae7 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,43 @@ +import os +import subprocess +from typing import Iterator + +import pytest + +from cockpit._vendor import systemd_ctypes + + +# run tests on a private user bus +@pytest.fixture(scope='session', autouse=True) +def mock_session_bus(tmp_path_factory: pytest.TempPathFactory) -> Iterator[None]: + # make sure nobody opened the user bus yet... + assert systemd_ctypes.Bus._default_user_instance is None + + tmpdir = tmp_path_factory.getbasetemp() + dbus_config = tmpdir / 'dbus-config' + dbus_addr = f'unix:path={tmpdir / "dbus_socket"}' + + dbus_config.write_text(fr""" + + + session + {dbus_addr} + + + + + + + + + + """) + dbus_daemon = subprocess.run( + ['dbus-daemon', f'--config-file={dbus_config}', '--print-pid'], stdout=subprocess.PIPE + ) + pid = int(dbus_daemon.stdout) + try: + os.environ['DBUS_SESSION_BUS_ADDRESS'] = dbus_addr + yield None + finally: + os.kill(pid, 9) diff --git a/test/pytest/mockdbusservice.py b/test/pytest/mockdbusservice.py new file mode 100644 index 000000000000..86db89d9afa2 --- /dev/null +++ b/test/pytest/mockdbusservice.py @@ -0,0 +1,171 @@ +import asyncio +import contextlib +import logging +import math +from collections.abc import AsyncIterator +from typing import Iterator + +from cockpit._vendor import systemd_ctypes + +logger = logging.getLogger(__name__) + + +# No introspection, manual handling of method calls +class borkety_Bork(systemd_ctypes.bus.BaseObject): + def message_received(self, message: systemd_ctypes.bus.BusMessage) -> bool: + signature = message.get_signature(True) # noqa:FBT003 + body = message.get_body() + logger.debug('got Bork message: %s %r', signature, body) + + if message.get_member() == 'Echo': + message.reply_method_return(signature, *body) + return True + + return False + + +class com_redhat_Cockpit_DBusTests_Frobber(systemd_ctypes.bus.Object): + finally_normal_name = systemd_ctypes.bus.Interface.Property('s', 'There aint no place like home') + readonly_property = systemd_ctypes.bus.Interface.Property('s', 'blah') + aay = systemd_ctypes.bus.Interface.Property('aay', [], name='aay') + ag = systemd_ctypes.bus.Interface.Property('ag', [], name='ag') + ao = systemd_ctypes.bus.Interface.Property('ao', [], name='ao') + as_ = systemd_ctypes.bus.Interface.Property('as', [], name='as') + ay = systemd_ctypes.bus.Interface.Property('ay', b'ABCabc\0', name='ay') + b = systemd_ctypes.bus.Interface.Property('b', value=False, name='b') + d = systemd_ctypes.bus.Interface.Property('d', 43, name='d') + g = systemd_ctypes.bus.Interface.Property('g', '', name='g') + i = systemd_ctypes.bus.Interface.Property('i', 0, name='i') + n = systemd_ctypes.bus.Interface.Property('n', 0, name='n') + o = systemd_ctypes.bus.Interface.Property('o', '/', name='o') + q = systemd_ctypes.bus.Interface.Property('q', 0, name='q') + s = systemd_ctypes.bus.Interface.Property('s', '', name='s') + t = systemd_ctypes.bus.Interface.Property('t', 0, name='t') + u = systemd_ctypes.bus.Interface.Property('u', 0, name='u') + x = systemd_ctypes.bus.Interface.Property('x', 0, name='x') + y = systemd_ctypes.bus.Interface.Property('y', 42, name='y') + + test_signal = systemd_ctypes.bus.Interface.Signal('i', 'as', 'ao', 'a{s(ii)}') + + @systemd_ctypes.bus.Interface.Method('', 'i') + def request_signal_emission(self, which_one: int) -> None: + del which_one + + self.test_signal( + 43, + ['foo', 'frobber'], + ['/foo', '/foo/bar'], + {'first': (42, 42), 'second': (43, 43)} + ) + + @systemd_ctypes.bus.Interface.Method('s', 's') + def hello_world(self, greeting: str) -> str: + return f"Word! You said `{greeting}'. I'm Skeleton, btw!" + + @systemd_ctypes.bus.Interface.Method('', '') + async def never_return(self) -> None: + await asyncio.sleep(1000000) + + @systemd_ctypes.bus.Interface.Method( + ['y', 'b', 'n', 'q', 'i', 'u', 'x', 't', 'd', 's', 'o', 'g', 'ay'], + ['y', 'b', 'n', 'q', 'i', 'u', 'x', 't', 'd', 's', 'o', 'g', 'ay'] + ) + def test_primitive_types( + self, + val_byte, val_boolean, + val_int16, val_uint16, val_int32, val_uint32, val_int64, val_uint64, + val_double, + val_string, val_objpath, val_signature, + val_bytestring + ): + return [ + val_byte + 10, + not val_boolean, + 100 + val_int16, + 1000 + val_uint16, + 10000 + val_int32, + 100000 + val_uint32, + 1000000 + val_int64, + 10000000 + val_uint64, + val_double / math.pi, + f"Word! You said `{val_string}'. Rock'n'roll!", + f"/modified{val_objpath}", + f"assgit{val_signature}", + b"bytestring!\xff\0" + ] + + @systemd_ctypes.bus.Interface.Method( + ['s'], + ["a{ss}", "a{s(ii)}", "(iss)", "as", "ao", "ag", "aay"] + ) + def test_non_primitive_types( + self, + dict_s_to_s, + dict_s_to_pairs, + a_struct, + array_of_strings, + array_of_objpaths, + array_of_signatures, + array_of_bytestrings + ): + return ( + f'{dict_s_to_s}{dict_s_to_pairs}{a_struct}' + f'array_of_strings: [{", ".join(array_of_strings)}] ' + f'array_of_objpaths: [{", ".join(array_of_objpaths)}] ' + f'array_of_signatures: [signature {", ".join(f"'{sig}'" for sig in array_of_signatures)}] ' + f'array_of_bytestrings: [{", ".join(x[:-1].decode() for x in array_of_bytestrings)}] ' + ) + + +@contextlib.contextmanager +def mock_service_export(bus: systemd_ctypes.Bus) -> Iterator[None]: + slots = [ + bus.add_object('/otree/frobber', com_redhat_Cockpit_DBusTests_Frobber()), + bus.add_object('/otree/different', com_redhat_Cockpit_DBusTests_Frobber()), + bus.add_object('/bork', borkety_Bork()) + ] + + yield + + for slot in slots: + slot.cancel() + + +@contextlib.asynccontextmanager +async def well_known_name(bus: systemd_ctypes.Bus, name: str, flags: int = 0) -> AsyncIterator[None]: + result, = await bus.call_method_async( + 'org.freedesktop.DBus', '/org/freedesktop/DBus', 'org.freedesktop.DBus', 'RequestName', 'su', name, flags + ) + if result != 1: + raise RuntimeError(f'Cannot register name {name}: {result}') + + try: + yield + + finally: + result, = await bus.call_method_async( + 'org.freedesktop.DBus', '/org/freedesktop/DBus', 'org.freedesktop.DBus', 'ReleaseName', 's', name + ) + if result != 1: + raise RuntimeError(f'Cannot release name {name}: {result}') + + +@contextlib.asynccontextmanager +async def mock_dbus_service_on_user_bus() -> AsyncIterator[None]: + user = systemd_ctypes.Bus.default_user() + async with ( + well_known_name(user, 'com.redhat.Cockpit.DBusTests.Test'), + well_known_name(user, 'com.redhat.Cockpit.DBusTests.Second'), + ): + with mock_service_export(user): + yield + + +async def main(): + async with mock_dbus_service_on_user_bus(): + print('Mock service running. Ctrl+C to exit.') + await asyncio.sleep(2 << 30) # "a long time." + + +if __name__ == '__main__': + systemd_ctypes.run_async(main()) diff --git a/test/pytest/mockwebserver.py b/test/pytest/mockwebserver.py new file mode 100644 index 000000000000..c0cc5d943bb1 --- /dev/null +++ b/test/pytest/mockwebserver.py @@ -0,0 +1,484 @@ +# This file is part of Cockpit. +# +# Copyright (C) 2024 Red Hat, Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +# https://github.com/astral-sh/ruff/issues/10980#issuecomment-2219615329 +# ruff: noqa: RUF029 + +import argparse +import asyncio +import binascii +import contextlib +import json +import logging +import os +import socket +import weakref +from collections.abc import AsyncIterator, Awaitable, Callable, Mapping +from pathlib import Path +from typing import ClassVar, NamedTuple, Self + +import aiohttp +from aiohttp import WSCloseCode, web +from yarl import URL + +from cockpit._vendor import systemd_ctypes +from cockpit.bridge import Bridge +from cockpit.jsonutil import JsonObject, JsonValue, create_object, get_enum, get_int, get_str, get_str_map +from cockpit.protocol import CockpitProblem, CockpitProtocolError + +from .mockdbusservice import mock_dbus_service_on_user_bus + +logger = logging.getLogger(__name__) + +websockets = web.AppKey("websockets", weakref.WeakSet) + + +class TextChannelOrigin(NamedTuple): + enqueue: Callable[[str | None], None] + + +class BinaryChannelOrigin(NamedTuple): + enqueue: Callable[[str | bytes | None], None] + + +class ExternalChannelOrigin(NamedTuple): + enqueue: Callable[[JsonObject | bytes], None] + + +ChannelOrigin = TextChannelOrigin | BinaryChannelOrigin | ExternalChannelOrigin + + +class MultiplexTransport(asyncio.Transport): + transports: ClassVar[dict[str, Self]] = {} + last_id: ClassVar[int] = 0 + + def __init__(self, protocol: asyncio.Protocol, origin: ChannelOrigin): + self.origins: dict[str | None, ChannelOrigin] = {None: origin} + self.channel_sequence = 0 + self.protocol = protocol + self.protocol.connection_made(self) + + self.csrf_token = f'token{MultiplexTransport.last_id}' + MultiplexTransport.transports[self.csrf_token] = self + MultiplexTransport.last_id += 1 + + def write(self, data: bytes) -> None: + # We know that cockpit.protocol always writes complete frames + header, _, frame = data.partition(b'\n') + assert int(header) == len(frame) + + channel_id, _, body = frame.partition(b'\n') + if channel_id: + # data message on the named channel + origin = self.origins.get(channel_id.decode()) + match origin: + case BinaryChannelOrigin(enqueue): + enqueue(frame) + case TextChannelOrigin(enqueue): + enqueue(frame.decode()) + case ExternalChannelOrigin(enqueue): + enqueue(body) + + else: + # control message (channel=None for transport control) + message = json.loads(body) + channel = get_str(message, 'channel', None) + origin = self.origins.get(channel) + + match origin: + case BinaryChannelOrigin(enqueue) | TextChannelOrigin(enqueue): + enqueue(frame.decode()) + case ExternalChannelOrigin(enqueue): + enqueue(message) + + print(message) + if origin is not None and get_str(message, 'command') == 'close': + del self.origins[channel] + + def register_origin(self, origin: ChannelOrigin, channel: str | None = None) -> str: + # normal channels get their IDs allocated in cockpit.js + + if channel is None: + # external channels get their IDs allocated by us + channel = f'external{self.channel_sequence}' + self.channel_sequence += 1 + + self.origins[channel] = origin + + return channel + + def data_received(self, data: bytes) -> None: + # cockpit.protocol expects a frame length header + header = f'{len(data)}\n'.encode() + self.protocol.data_received(header + data) + + def control_received(self, message: JsonObject) -> None: + self.data_received(b'\n' + json.dumps(message).encode()) + + def close(self) -> None: + assert MultiplexTransport.transports.pop(self.csrf_token) is self + + +class CockpitWebSocket(web.WebSocketResponse): + def __init__(self): + self.outgoing_queue = asyncio.Queue[str | bytes | None]() + super().__init__(protocols=['cockpit1']) + + async def send_control(self, _msg: JsonObject | None = None, **kwargs: JsonValue) -> None: + await self.send_str('\n' + json.dumps(create_object(_msg, kwargs))) + + async def process_outgoing_queue(self, queue: asyncio.Queue[str | bytes | None]) -> None: + while True: + item = await queue.get() + if isinstance(item, str): + await self.send_str(item) + elif isinstance(item, bytes): + await self.send_bytes(item) + else: + break + + async def communicate(self, request: web.Request) -> None: + text_origin = TextChannelOrigin(self.outgoing_queue.put_nowait) + binary_origin = BinaryChannelOrigin(self.outgoing_queue.put_nowait) + + try: + bridge = Bridge(argparse.Namespace(privileged=False, beipack=False)) + transport = MultiplexTransport(bridge, text_origin) + + # wait for the bridge to send its "init" + bridge_init = await self.outgoing_queue.get() + del bridge_init + + # send our "init" to the websocket + await self.prepare(request) + await self.send_control( + command='init', version=1, host='localhost', + channel_seed='test-server', csrf_token=transport.csrf_token, + capabilities=['multi', 'credentials', 'binary'], + system={'version': '0'} + ) + + # receive "init" from the websocket + try: + assert await self.receive_json() == {'command': 'init', 'version': 1} + except (TypeError, json.JSONDecodeError, AssertionError) as exc: + raise CockpitProtocolError('expected init message') from exc + + # send "init" to the bridge + # TODO: explicit-superuser handling + transport.data_received(b'\n' + json.dumps({ + "command": "init", + "version": 1, + "host": "localhost" + }).encode()) + + write_task = asyncio.create_task(self.process_outgoing_queue(self.outgoing_queue)) + + try: + async for msg in self: + if msg.type == aiohttp.WSMsgType.TEXT: + frame = msg.data + if frame.startswith('\n'): + control = json.loads(frame) + command = get_str(control, 'command') + channel = get_str(control, 'channel', None) + if command == 'open': + if channel is None: + raise CockpitProtocolError('open message without channel') + binary = get_enum(control, 'binary', ['raw'], None) == 'raw' + transport.register_origin(binary_origin if binary else text_origin, channel) + transport.data_received(frame.encode()) + elif msg.type == aiohttp.WSMsgType.BINARY: + transport.data_received(msg.data) + else: + raise CockpitProtocolError(f'strange websocket message {msg!s}') + finally: + self.outgoing_queue.put_nowait(None) + await write_task + + except CockpitProblem as exc: + if not self.closed: + await self.send_control(exc.get_attrs(), command='close') + + +routes = web.RouteTableDef() + + +@routes.get(r'/favicon.ico') +async def favicon_ico(request: web.Request) -> web.FileResponse: + del request + return web.FileResponse('src/branding/default/favicon.ico') + + +SPLIT_UTF8_FRAMES = [ + b"initial", + # split an é in the middle + b"first half \xc3", + b"\xa9 second half", + b"final" +] + + +@routes.get(r'/mock/expect-warnings') +@routes.get(r'/mock/dont-expect-warnings') +async def mock_expect_warnings(_request: web.Request) -> web.Response: + # no op — only for compatibility with C test-server + return web.Response(status=200, text='OK') + + +@routes.get(r'/mock/info') +async def mock_info(_request: web.Request) -> web.Response: + return web.json_response({ + 'pybridge': True, + 'skip_slow_tests': 'COCKPIT_SKIP_SLOW_TESTS' in os.environ + }) + + +@routes.get(r'/mock/stream') +async def mock_stream(request: web.Request) -> web.StreamResponse: + response = web.StreamResponse() + await response.prepare(request) + + for i in range(10): + await response.write(f'{i} '.encode()) + + return response + + +@routes.get(r'/mock/split-utf8') +async def mock_split_utf8(request: web.Request) -> web.StreamResponse: + response = web.StreamResponse() + await response.prepare(request) + + for chunk in SPLIT_UTF8_FRAMES: + await response.write(chunk) + + return response + + +@routes.get(r'/mock/truncated-utf8') +async def mock_truncated_utf8(request: web.Request) -> web.StreamResponse: + response = web.StreamResponse() + await response.prepare(request) + + for chunk in SPLIT_UTF8_FRAMES[0:2]: + await response.write(chunk) + + return response + + +@routes.get(r'/mock/headers') +async def mock_headers(request: web.Request) -> web.Response: + headers = {k: v for k, v in request.headers.items() if k.startswith('Header')} + headers['Header3'] = 'three' + headers['Header4'] = 'marmalade' + + return web.Response(status=201, text='Yoo Hoo', headers=headers) + + +@routes.get(r'/mock/host') +async def mock_host(request: web.Request) -> web.Response: + return web.Response(status=201, text='Yoo Hoo', headers={'Host': request.headers['Host']}) + + +@routes.get(r'/mock/headonly') +async def mock_headonly(request: web.Request) -> web.Response: + if request.method != 'HEAD': + return web.Response(status=400, reason="Only HEAD allowed on this path") + + input_data = request.headers.get('InputData') + if not input_data: + return web.Response(status=400, reason="Requires InputData header") + + return web.Response(status=200, text='OK', headers={'InputDataLength': str(len(input_data))}) + + +@routes.get(r'/mock/qs') +async def mock_qs(request: web.Request) -> web.Response: + return web.Response(text=request.query_string.replace(' ', '+')) + + +@routes.get(r'/cockpit/channel/{csrf_token}') +async def cockpit_channel(request: web.Request) -> web.StreamResponse: + try: + transport = MultiplexTransport.transports[request.match_info['csrf_token']] + except KeyError: + return web.Response(status=404) + + # Decode the request + try: + options = json.loads(binascii.a2b_base64(request.query_string)) + except (json.JSONDecodeError, binascii.Error) as exc: + return web.Response(status=400, reason=f'Invalid query string {exc!s}') + + binary = get_enum(options, 'binary', ['raw'], None) == 'raw' + websocket = request.headers.get('Upgrade', '').lower() == 'websocket' + + # Open the channel, requesting data send to our queue + queue = asyncio.Queue[JsonObject | bytes]() + channel = transport.register_origin(ExternalChannelOrigin(queue.put_nowait)) + transport.control_received({**options, 'command': 'open', 'channel': channel, 'flow-control': True}) + + # The first thing the channel sends back will be 'ready' or 'close' + open_result = await queue.get() + assert isinstance(open_result, Mapping) + if get_str(open_result, 'command') != 'ready': + return web.json_response(open_result, status=400, reason='Failed to open channel') + + # Start streaming the result. + if websocket: + response = web.WebSocketResponse() + await response.prepare(request) + + else: + # Send the 'external' field back as the HTTP headers... + headers = {**get_str_map(options, 'external', {})} + + if 'Content-Type' not in headers: + headers['Content-Type'] = 'application/octet-stream' if binary else 'text/plain' + + # ...plus this, if we have it. + if size_hint := get_int(open_result, 'size-hint', None): + headers['Content-Length'] = f'{size_hint}' + + response = web.StreamResponse(status=200, headers=headers) + await response.prepare(request) + + # Now, handle the data we receive + while item := await queue.get(): + match item: + case Mapping(): + match get_str(item, 'command'): + case 'ping': + transport.control_received({**item, 'command': 'pong'}) + case 'close' | 'done': + break + + case bytes(): + await response.write(item) + + return response + + +@routes.get(r'/cockpit/socket') +async def cockpit_socket(request: web.Request) -> web.WebSocketResponse: + ws = CockpitWebSocket() + request.app[websockets].add(ws) + await ws.communicate(request) + return ws + + +@routes.get('/') +async def index(_request: web.Request) -> web.Response: + cases = Path('qunit').rglob('test-*.html') + + result = ( + """ + + + Test cases + + + + + + """ + ) + + return web.Response(text=result, content_type='text/html') + + +@routes.get(r'/{name:(pkg|dist|qunit)/.+}') +async def serve_file(request: web.Request) -> web.FileResponse: + path = Path('.') / request.match_info['name'] + return web.FileResponse(path) + + +COMMON_HEADERS = { + "Cross-Origin-Resource-Policy": "same-origin", + "Referrer-Policy": "no-referrer", + "X-Content-Type-Options": "nosniff", + "X-DNS-Prefetch-Control": "off", + "X-Frame-Options": "sameorigin", +} + + +@web.middleware +async def cockpit_middleware( + request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]] +) -> web.StreamResponse: + try: + response = await handler(request) + except web.HTTPException as ex: + response = web.Response( + status=ex.status, reason=ex.reason, text=f'

{ex.reason}

', content_type='text/html' + ) + + response.headers.update(COMMON_HEADERS) + return response + + +@contextlib.asynccontextmanager +async def mock_webserver(addr: str = '127.0.0.1', port: int = 0) -> AsyncIterator[URL]: + async with mock_dbus_service_on_user_bus(): + app = web.Application(middlewares=[cockpit_middleware]) + + # https://docs.aiohttp.org/en/stable/web_advanced.html#websocket-shutdown + async def on_shutdown(app: web.Application): + for ws in set(app[websockets]): + await ws.close(code=WSCloseCode.GOING_AWAY, message="Server shutdown") + app[websockets] = weakref.WeakSet() + app.on_shutdown.append(on_shutdown) + app.add_routes(routes) + + runner = web.AppRunner(app) + await runner.setup() + + listener = socket.socket() + listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listener.bind((addr, port)) + listener.listen() + site = web.SockSite(runner, listener) + await site.start() + + addr, port = listener.getsockname() + yield URL(f'http://{addr}:{port}/') + + logger.debug('cleaning up mock webserver') + await runner.cleanup() + logger.debug('cleaning up mock webserver complete') + + +async def main() -> None: + parser = argparse.ArgumentParser(description="Serve a single git repository via HTTP") + parser.add_argument('--addr', '-a', default='127.0.0.1', help="Address to bind to") + parser.add_argument('--port', '-p', type=int, default=8080, help="Port number to bind to") + args = parser.parse_args() + + logging.basicConfig(level=logging.DEBUG) + + async with mock_webserver(args.addr, args.port) as url: + print(f"\n {url}\n\nCtrl+C to exit.") + await asyncio.sleep(1000000) + + +if __name__ == '__main__': + systemd_ctypes.run_async(main()) diff --git a/test/pytest/test_browser.py b/test/pytest/test_browser.py index f6cdf7c4de7c..650e5a5a7002 100644 --- a/test/pytest/test_browser.py +++ b/test/pytest/test_browser.py @@ -11,6 +11,8 @@ from webdriver_bidi import ChromiumBidi from yarl import URL +from .mockwebserver import mock_webserver + SRCDIR = os.path.realpath(f'{__file__}/../../..') BUILDDIR = os.environ.get('abs_builddir', SRCDIR) @@ -22,6 +24,14 @@ 'base1/test-websocket.html', } +MOCK_WEBSERVER_XFAIL = { + 'base1/test-dbus.html', + 'base1/test-external.html', + 'base1/test-http.html', + 'base1/test-stream.html', + 'shell/machines/test-machines.html', +} + @contextlib.asynccontextmanager async def spawn_test_server() -> AsyncIterator[URL]: # noqa:RUF029 @@ -53,14 +63,17 @@ async def spawn_test_server() -> AsyncIterator[URL]: # noqa:RUF029 @pytest.mark.asyncio @pytest.mark.parametrize('html', glob.glob('**/test-*.html', root_dir=f'{SRCDIR}/qunit', recursive=True)) -async def test_browser(coverage_report: CoverageReport, html: str) -> None: +@pytest.mark.parametrize('server', ['test-server', 'mock_webserver']) +async def test_browser(coverage_report: CoverageReport, html: str, server: str) -> None: if html in SKIP: pytest.skip() elif html in XFAIL: pytest.xfail() + elif server == 'mock_webserver' and html in MOCK_WEBSERVER_XFAIL: + pytest.xfail() async with ( - spawn_test_server() as base_url, + (mock_webserver if server == 'mock_webserver' else spawn_test_server)() as base_url, ChromiumBidi(headless=os.environ.get('TEST_SHOW_BROWSER', '0') == '0') as browser ): await browser.cdp("Profiler.enable") @@ -119,4 +132,4 @@ async def test_timeformat_timezones( coverage_report: CoverageReport, tz: str, monkeypatch: pytest.MonkeyPatch ) -> None: monkeypatch.setenv('TZ', tz) - await test_browser(coverage_report, 'base1/test-timeformat.html') + await test_browser(coverage_report, 'base1/test-timeformat.html', 'mock_webserver')