From 83206e884b0e007dfe7febc6b8373b4775c0639a Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Thu, 16 May 2024 00:41:23 -0400 Subject: [PATCH] feat(anta): limit the number of tests run concurrently --- anta/runner.py | 170 +++++++++++++++++++++++++++++-------- tests/units/test_runner.py | 9 +- 2 files changed, 139 insertions(+), 40 deletions(-) diff --git a/anta/runner.py b/anta/runner.py index dcb2d962e..e011dd577 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -18,7 +18,8 @@ from anta.tools import Catchtime, cprofile if TYPE_CHECKING: - from collections.abc import Coroutine + from asyncio import Task + from collections.abc import AsyncGenerator, Coroutine from anta.catalog import AntaCatalog, AntaTestDefinition from anta.device import AntaDevice @@ -29,6 +30,30 @@ logger = logging.getLogger(__name__) DEFAULT_NOFILE = 16384 +"""Default number of open file descriptors for the ANTA process.""" +DEFAULT_MAX_CONCURRENCY = 10000 +"""Default maximum number of tests to run concurrently.""" +DEFAULT_MAX_CONNECTIONS = 100 +"""Default underlying HTTPX client maximum number of connections per device.""" + + +def adjust_max_concurrency() -> int: + """Adjust the maximum number of tests (coroutines) to run concurrently. + + The limit is set to the value of the ANTA_MAX_CONCURRENCY environment variable. + + If the `ANTA_MAX_CONCURRENCY` environment variable is not set or is invalid, `DEFAULT_MAX_CONCURRENCY` is used. + + Returns + ------- + The maximum number of tests to run concurrently. + """ + try: + max_concurrency = int(os.environ.get("ANTA_MAX_CONCURRENCY", DEFAULT_MAX_CONCURRENCY)) + except ValueError as exception: + logger.warning("The ANTA_MAX_CONCURRENCY environment variable value is invalid: %s\nDefault to %s.", exc_to_str(exception), DEFAULT_MAX_CONCURRENCY) + max_concurrency = DEFAULT_MAX_CONCURRENCY + return max_concurrency def adjust_rlimit_nofile() -> tuple[int, int]: @@ -40,7 +65,6 @@ def adjust_rlimit_nofile() -> tuple[int, int]: Returns ------- - tuple[int, int] The new soft and hard limits for open file descriptors. """ try: @@ -77,6 +101,61 @@ def log_cache_statistics(devices: list[AntaDevice]) -> None: logger.info("Caching is not enabled on %s", device.name) +async def run(tests_generator: AsyncGenerator[Coroutine[Any, Any, TestResult], None], limit: int) -> AsyncGenerator[TestResult, None]: + """Run tests with a concurrency limit. + + This function takes an asynchronous generator of test coroutines and runs them + with a limit on the number of concurrent tests. It yields test results as each + test completes. + + Inspired by: https://death.andgravity.com/limit-concurrency + + Parameters + ---------- + tests_generator + An asynchronous generator that yields test coroutines. + limit + The maximum number of concurrent tests to run. + + Yields + ------ + The result of each completed test. + """ + # NOTE: The `aiter` built-in function is not available in Python 3.9 + aws = tests_generator.__aiter__() # pylint: disable=unnecessary-dunder-call + aws_ended = False + pending: set[Task[TestResult]] = set() + + while pending or not aws_ended: + # Add tests to the pending set until the limit is reached or no more tests are available + while len(pending) < limit and not aws_ended: + try: + # NOTE: The `anext` built-in function is not available in Python 3.9 + aw = await aws.__anext__() # pylint: disable=unnecessary-dunder-call + except StopAsyncIteration: # noqa: PERF203 + aws_ended = True + logger.debug("All tests have been added to the pending set.") + else: + # Ensure the coroutine is scheduled to run and add it to the pending set + pending.add(asyncio.create_task(aw)) + logger.debug("Added a test to the pending set: %s", aw) + + if len(pending) >= limit: + logger.debug("Concurrency limit reached: %s tests running. Waiting for tests to complete.", limit) + + if not pending: + logger.debug("No pending tests and all tests have been processed. Exiting.") + return + + # Wait for at least one of the pending tests to complete + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + logger.debug("Completed %s test(s). Pending count: %s", len(done), len(pending)) + + # Yield results of completed tests + while done: + yield await done.pop() + + async def setup_inventory(inventory: AntaInventory, tags: set[str] | None, devices: set[str] | None, *, established_only: bool) -> AntaInventory | None: """Set up the inventory for the ANTA run. @@ -93,8 +172,7 @@ async def setup_inventory(inventory: AntaInventory, tags: set[str] | None, devic Returns ------- - AntaInventory | None - The filtered inventory or None if there are no devices to run tests on. + The filtered AntaInventory or None if there are no devices to run tests on. """ if len(inventory) == 0: logger.info("The inventory is empty, exiting") @@ -119,10 +197,10 @@ async def setup_inventory(inventory: AntaInventory, tags: set[str] | None, devic return selected_inventory -def prepare_tests( +def setup_tests( inventory: AntaInventory, catalog: AntaCatalog, tests: set[str] | None, tags: set[str] | None -) -> defaultdict[AntaDevice, set[AntaTestDefinition]] | None: - """Prepare the tests to run. +) -> tuple[int, defaultdict[AntaDevice, set[AntaTestDefinition]] | None]: + """Set up the tests for the ANTA run. Parameters ---------- @@ -137,17 +215,17 @@ def prepare_tests( Returns ------- - defaultdict[AntaDevice, set[AntaTestDefinition]] | None - A mapping of devices to the tests to run or None if there are no tests to run. + The total number of tests and a mapping of devices to the tests to run or None if there are no tests to run. """ # Build indexes for the catalog. If `tests` is set, filter the indexes based on these tests catalog.build_indexes(filtered_tests=tests) + total_tests = 0 + # Using a set to avoid inserting duplicate tests device_to_tests: defaultdict[AntaDevice, set[AntaTestDefinition]] = defaultdict(set) - # Create AntaTestRunner tuples from the tags - final_tests_count = 0 + # Create the mapping of devices to the tests to run for device in inventory.devices: if tags: if not any(tag in device.tags for tag in tags): @@ -160,40 +238,42 @@ def prepare_tests( # Add the tests with matching tags from device tags device_to_tests[device].update(catalog.get_tests_by_tags(device.tags)) - final_tests_count += len(device_to_tests[device]) + total_tests += len(device_to_tests[device]) - if len(device_to_tests.values()) == 0: + if total_tests == 0: msg = ( f"There are no tests{f' matching the tags {tags} ' if tags else ' '}to run in the current test catalog and device inventory, please verify your inputs." ) logger.warning(msg) - return None + return total_tests, None - return device_to_tests + return total_tests, device_to_tests -def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]], manager: ResultManager) -> list[Coroutine[Any, Any, TestResult]]: +async def test_generator( + selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]], manager: ResultManager +) -> AsyncGenerator[Coroutine[Any, Any, TestResult], None]: """Get the coroutines for the ANTA run. + It creates an async generator of coroutines which are created by the `test` method of the AntaTest instances. Each coroutine is a test to run. + Parameters ---------- selected_tests - A mapping of devices to the tests to run. The selected tests are generated by the `prepare_tests` function. + A mapping of devices to the tests to run. The selected tests are created by the `setup_tests` function. manager A ResultManager - Returns - ------- - list[Coroutine[Any, Any, TestResult]] - The list of coroutines to run. + Yields + ------ + The coroutine (test) to run. """ - coros = [] for device, test_definitions in selected_tests.items(): for test in test_definitions: try: test_instance = test.test(device=device, inputs=test.inputs) manager.add(test_instance.result) - coros.append(test_instance.test()) + coroutine = test_instance.test() except Exception as e: # noqa: PERF203, BLE001 # An AntaTest instance is potentially user-defined code. # We need to catch everything and exit gracefully with an error message. @@ -204,7 +284,8 @@ def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinitio ], ) anta_log_exception(e, message, logger) - return coros + else: + yield coroutine @cprofile() @@ -246,6 +327,9 @@ async def main( # noqa: PLR0913 # Adjust the maximum number of open file descriptors for the ANTA process limits = adjust_rlimit_nofile() + # Adjust the maximum number of tests to run concurrently + max_concurrency = adjust_max_concurrency() + if not catalog.tests: logger.info("The list of tests is empty, exiting") return @@ -257,40 +341,54 @@ async def main( # noqa: PLR0913 return with Catchtime(logger=logger, message="Preparing the tests"): - selected_tests = prepare_tests(selected_inventory, catalog, tests, tags) - if selected_tests is None: + total_tests, selected_tests = setup_tests(selected_inventory, catalog, tests, tags) + if total_tests == 0 or selected_tests is None: return final_tests_count = sum(len(tests) for tests in selected_tests.values()) + generator = test_generator(selected_tests, manager) + run_info = ( - "--- ANTA NRFU Run Information ---\n" + "------------------------------------ ANTA NRFU Run Information -------------------------------------\n" f"Number of devices: {len(inventory)} ({len(selected_inventory)} established)\n" - f"Total number of selected tests: {final_tests_count}\n" + f"Total number of selected tests: {total_tests}\n" + f"Maximum number of tests to run concurrently: {max_concurrency}\n" + f"Maximum number of connections per device: {DEFAULT_MAX_CONNECTIONS}\n" f"Maximum number of open file descriptors for the current ANTA process: {limits[0]}\n" - "---------------------------------" + "----------------------------------------------------------------------------------------------------" ) logger.info(run_info) - if final_tests_count > limits[0]: + total_potential_connections = len(selected_inventory) * DEFAULT_MAX_CONNECTIONS + + if total_tests > max_concurrency: + logger.warning( + "The total number of tests is higher than the maximum number of tests to run concurrently.\n" + "ANTA will be throttled to run at the maximum number of tests to run concurrently to ensure system stability.\n" + "Please consult the ANTA FAQ." + ) + if total_potential_connections > limits[0]: logger.warning( - "The number of concurrent tests is higher than the open file descriptors limit for this ANTA process.\n" + "The total potential connections to devices is higher than the open file descriptors limit for this ANTA process.\n" "Errors may occur while running the tests.\n" "Please consult the ANTA FAQ." ) - coroutines = get_coroutines(selected_tests, manager) + # Cleanup no longer needed objects before running the tests + del selected_tests if dry_run: logger.info("Dry-run mode, exiting before running the tests.") - for coro in coroutines: - coro.close() + async for test in generator: + test.close() return if AntaTest.progress is not None: - AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coroutines)) + AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=final_tests_count) with Catchtime(logger=logger, message="Running ANTA tests"): - await asyncio.gather(*coroutines) + async for result in run(generator, limit=max_concurrency): + logger.debug(result) log_cache_statistics(selected_inventory.devices) diff --git a/tests/units/test_runner.py b/tests/units/test_runner.py index b80259cc3..ebf46c01f 100644 --- a/tests/units/test_runner.py +++ b/tests/units/test_runner.py @@ -16,7 +16,7 @@ from anta.catalog import AntaCatalog from anta.inventory import AntaInventory from anta.result_manager import ResultManager -from anta.runner import adjust_rlimit_nofile, main, prepare_tests +from anta.runner import adjust_rlimit_nofile, main, setup_tests from .test_models import FakeTest, FakeTestWithMissingTest @@ -141,13 +141,13 @@ def side_effect_setrlimit(resource_id: int, limits: tuple[int, int]) -> None: ], indirect=["inventory"], ) -async def test_prepare_tests( +async def test_setup_tests( caplog: pytest.LogCaptureFixture, inventory: AntaInventory, tags: set[str], tests: set[str], devices_count: int, tests_count: int ) -> None: - """Test the runner prepare_tests function with specific tests.""" + """Test the runner setup_tests function with specific tests.""" caplog.set_level(logging.WARNING) catalog: AntaCatalog = AntaCatalog.parse(str(DATA_DIR / "test_catalog_with_tags.yml")) - selected_tests = prepare_tests(inventory=inventory, catalog=catalog, tags=tags, tests=tests) + total_tests, selected_tests = setup_tests(inventory=inventory, catalog=catalog, tags=tags, tests=tests) if selected_tests is None: msg = f"There are no tests matching the tags {tags} to run in the current test catalog and device inventory, please verify your inputs." assert msg in caplog.messages @@ -155,6 +155,7 @@ async def test_prepare_tests( assert selected_tests is not None assert len(selected_tests) == devices_count assert sum(len(tests) for tests in selected_tests.values()) == tests_count + assert total_tests == tests_count async def test_dry_run(caplog: pytest.LogCaptureFixture, inventory: AntaInventory) -> None: