diff --git a/colcon_core/extension_point.py b/colcon_core/extension_point.py index ff348ba5..c724d221 100644 --- a/colcon_core/extension_point.py +++ b/colcon_core/extension_point.py @@ -37,6 +37,15 @@ EXTENSION_POINT_GROUP_NAME = 'colcon_core.extension_point' +def _get_unique_distributions(): + seen = set() + for dist in distributions(): + dist_name = dist.metadata['Name'] + if dist_name not in seen: + seen.add(dist_name) + yield dist + + def get_all_extension_points(): """ Get all extension points related to `colcon` and any of its extensions. @@ -51,12 +60,7 @@ def get_all_extension_points(): colcon_extension_points.setdefault(EXTENSION_POINT_GROUP_NAME, None) entry_points = defaultdict(dict) - seen = set() - for dist in distributions(): - dist_name = dist.metadata['Name'] - if dist_name in seen: - continue - seen.add(dist_name) + for dist in _get_unique_distributions(): for entry_point in dist.entry_points: # skip groups which are not registered as extension points if entry_point.group not in colcon_extension_points: @@ -70,7 +74,7 @@ def get_all_extension_points(): f"from '{dist._path}' " f"overwriting '{previous}'") entry_points[entry_point.group][entry_point.name] = \ - (entry_point.value, dist_name, dist.version) + (entry_point.value, dist.metadata['Name'], dist.version) return entry_points @@ -87,7 +91,11 @@ def get_extension_points(group): # Python 3.10 and newer query = entry_points(group=group) except TypeError: - query = entry_points().get(group, ()) + query = ( + entry_point + for dist in _get_unique_distributions() + for entry_point in dist.entry_points + if entry_point.group == group) for entry_point in query: if entry_point.name in extension_points: previous_entry_point = extension_points[entry_point.name] diff --git a/test/test_extension_point.py b/test/test_extension_point.py index 7111b796..96e58a0d 100644 --- a/test/test_extension_point.py +++ b/test/test_extension_point.py @@ -54,8 +54,8 @@ def iter_entry_points(*, group=None): def distributions(): return [ - Dist(iter_entry_points(group='group1')), - Dist([EntryPoint('extC', 'eC', Group2.name)]), + Dist([Group1, ExtA, ExtB]), + Dist([Group2, EntryPoint('extC', 'eC', Group2.name)]), Dist([EntryPoint('extD', 'eD', 'groupX')]), ] @@ -71,7 +71,11 @@ def test_all_extension_points(): ): # successfully load a known entry point extension_points = get_all_extension_points() - assert set(extension_points.keys()) == {'group1', 'group2'} + assert set(extension_points.keys()) == { + EXTENSION_POINT_GROUP_NAME, + 'group1', + 'group2', + } assert set(extension_points['group1'].keys()) == {'extA', 'extB'} assert extension_points['group1']['extA'][0] == 'eA'