Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix entry point discovery on Python < 3.10 #604

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions colcon_core/extension_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@
EXTENSION_POINT_GROUP_NAME = 'colcon_core.extension_point'


def _get_unique_distributions():
seen = set()
for dist in distributions():
dist_name = dist.metadata['Name']
if dist_name not in seen:
seen.add(dist_name)
yield dist


def get_all_extension_points():
"""
Get all extension points related to `colcon` and any of its extensions.
Expand All @@ -51,12 +60,7 @@ def get_all_extension_points():
colcon_extension_points.setdefault(EXTENSION_POINT_GROUP_NAME, None)

entry_points = defaultdict(dict)
seen = set()
for dist in distributions():
dist_name = dist.metadata['Name']
if dist_name in seen:
continue
seen.add(dist_name)
for dist in _get_unique_distributions():
for entry_point in dist.entry_points:
# skip groups which are not registered as extension points
if entry_point.group not in colcon_extension_points:
Expand All @@ -70,7 +74,7 @@ def get_all_extension_points():
f"from '{dist._path}' "
f"overwriting '{previous}'")
entry_points[entry_point.group][entry_point.name] = \
(entry_point.value, dist_name, dist.version)
(entry_point.value, dist.metadata['Name'], dist.version)
return entry_points


Expand All @@ -87,7 +91,11 @@ def get_extension_points(group):
# Python 3.10 and newer
query = entry_points(group=group)
except TypeError:
query = entry_points().get(group, ())
query = (
entry_point
for dist in _get_unique_distributions()
for entry_point in dist.entry_points
if entry_point.group == group)
for entry_point in query:
if entry_point.name in extension_points:
previous_entry_point = extension_points[entry_point.name]
Expand Down
10 changes: 7 additions & 3 deletions test/test_extension_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def iter_entry_points(*, group=None):

def distributions():
return [
Dist(iter_entry_points(group='group1')),
Dist([EntryPoint('extC', 'eC', Group2.name)]),
Dist([Group1, ExtA, ExtB]),
Dist([Group2, EntryPoint('extC', 'eC', Group2.name)]),
Dist([EntryPoint('extD', 'eD', 'groupX')]),
]

Expand All @@ -71,7 +71,11 @@ def test_all_extension_points():
):
# successfully load a known entry point
extension_points = get_all_extension_points()
assert set(extension_points.keys()) == {'group1', 'group2'}
assert set(extension_points.keys()) == {
EXTENSION_POINT_GROUP_NAME,
'group1',
'group2',
}
assert set(extension_points['group1'].keys()) == {'extA', 'extB'}
assert extension_points['group1']['extA'][0] == 'eA'

Expand Down