diff --git a/colcon_core/extension_point.py b/colcon_core/extension_point.py index c724d2213..16fc340d3 100644 --- a/colcon_core/extension_point.py +++ b/colcon_core/extension_point.py @@ -4,6 +4,7 @@ from collections import defaultdict import os +import sys import traceback try: @@ -26,7 +27,6 @@ logger = colcon_logger.getChild(__name__) - """ The group name for entry points identifying colcon extension points. @@ -36,6 +36,8 @@ """ EXTENSION_POINT_GROUP_NAME = 'colcon_core.extension_point' +_ENTRY_POINTS_CACHE = [] + def _get_unique_distributions(): seen = set() @@ -46,6 +48,43 @@ def _get_unique_distributions(): yield dist +def _get_entry_points(): + for dist in _get_unique_distributions(): + for entry_point in dist.entry_points: + # Modern EntryPoint instances should already have this set + if not hasattr(entry_point, 'dist'): + entry_point.dist = dist + yield entry_point + + +def _get_cached_entry_points(): + if not _ENTRY_POINTS_CACHE: + if sys.version_info >= (3, 10): + # We prefer using importlib.metadata.entry_points because it + # has an internal optimization which allows us to load the entry + # points without reading the individual PKG-INFO files, while + # still visiting each unique distribution only once. + if sys.version_info >= (3, 12): + _ENTRY_POINTS_CACHE.extend(entry_points()) + else: + # Prior to Python 3.12, entry_points returned a (deprecated) + # dict. Unfortunately, the "future-proof" recommended + # pattern is to add filter parameters, but we actually + # want to cache everything so that doesn't work here. + _ENTRY_POINTS_CACHE.extend(entry_points().values()) + else: + # If we don't have Python 3.10, we must read each PKG-INFO to + # get the name of the distribution so that we can skip the + # "shadowed" distributions properly. + _ENTRY_POINTS_CACHE.extend(_get_entry_points()) + return _ENTRY_POINTS_CACHE + + +def clear_entry_point_cache(): + """Purge the entry point cache.""" + _ENTRY_POINTS_CACHE.clear() + + def get_all_extension_points(): """ Get all extension points related to `colcon` and any of its extensions. @@ -59,23 +98,24 @@ def get_all_extension_points(): colcon_extension_points = get_extension_points(EXTENSION_POINT_GROUP_NAME) colcon_extension_points.setdefault(EXTENSION_POINT_GROUP_NAME, None) - entry_points = defaultdict(dict) - for dist in _get_unique_distributions(): - for entry_point in dist.entry_points: - # skip groups which are not registered as extension points - if entry_point.group not in colcon_extension_points: - continue - - if entry_point.name in entry_points[entry_point.group]: - previous = entry_points[entry_point.group][entry_point.name] - logger.error( - f"Entry point '{entry_point.group}.{entry_point.name}' is " - f"declared multiple times, '{entry_point.value}' " - f"from '{dist._path}' " - f"overwriting '{previous}'") - entry_points[entry_point.group][entry_point.name] = \ - (entry_point.value, dist.metadata['Name'], dist.version) - return entry_points + extension_points = defaultdict(dict) + for entry_point in _get_cached_entry_points(): + if entry_point.group not in colcon_extension_points: + continue + + dist_metadata = entry_point.dist.metadata + ep_tuple = ( + entry_point.value, + dist_metadata['Name'], dist_metadata['Version'], + ) + if entry_point.name in extension_points[entry_point.group]: + previous = extension_points[entry_point.group][entry_point.name] + logger.error( + f"Entry point '{entry_point.group}.{entry_point.name}' is " + f"declared multiple times, '{ep_tuple}' " + f"overwriting '{previous}'") + extension_points[entry_point.group][entry_point.name] = ep_tuple + return extension_points def get_extension_points(group): @@ -87,16 +127,9 @@ def get_extension_points(group): :rtype: dict """ extension_points = {} - try: - # Python 3.10 and newer - query = entry_points(group=group) - except TypeError: - query = ( - entry_point - for dist in _get_unique_distributions() - for entry_point in dist.entry_points - if entry_point.group == group) - for entry_point in query: + for entry_point in _get_cached_entry_points(): + if entry_point.group != group: + continue if entry_point.name in extension_points: previous_entry_point = extension_points[entry_point.name] logger.error( diff --git a/test/test_extension_point.py b/test/test_extension_point.py index 6d6269619..d5cdba434 100644 --- a/test/test_extension_point.py +++ b/test/test_extension_point.py @@ -12,6 +12,7 @@ # TODO: Drop this with Python 3.7 support from importlib_metadata import Distribution +from colcon_core.extension_point import clear_entry_point_cache from colcon_core.extension_point import EntryPoint from colcon_core.extension_point import EXTENSION_POINT_GROUP_NAME from colcon_core.extension_point import get_all_extension_points @@ -73,6 +74,8 @@ def test_all_extension_points(): 'colcon_core.extension_point.distributions', side_effect=_distributions ): + clear_entry_point_cache() + # successfully load a known entry point extension_points = get_all_extension_points() assert set(extension_points.keys()) == { @@ -94,12 +97,14 @@ def test_extension_point_blocklist(): 'colcon_core.extension_point.distributions', side_effect=_distributions ): + clear_entry_point_cache() extension_points = get_extension_points('group1') assert 'extA' in extension_points.keys() extension_point = extension_points['extA'] assert extension_point == 'eA' with patch.object(EntryPoint, 'load', return_value=None) as load: + clear_entry_point_cache() load_extension_point('extA', 'eA', 'group1') assert load.call_count == 1 @@ -108,12 +113,14 @@ def test_extension_point_blocklist(): with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([ 'group1.extB', 'group2.extC']) ): + clear_entry_point_cache() load_extension_point('extA', 'eA', 'group1') assert load.call_count == 1 # entry point in a blocked group can't be loaded load.reset_mock() with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST='group1'): + clear_entry_point_cache() with pytest.raises(RuntimeError) as e: load_extension_point('extA', 'eA', 'group1') assert 'The entry point group name is listed in the environment ' \ @@ -124,6 +131,7 @@ def test_extension_point_blocklist(): with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([ 'group1.extA', 'group1.extB']) ): + clear_entry_point_cache() with pytest.raises(RuntimeError) as e: load_extension_point('extA', 'eA', 'group1') assert 'The entry point name is listed in the environment ' \