Skip to content

Commit

Permalink
feat: annotate plugin.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kdmccormick committed Jun 12, 2024
1 parent 153348f commit 1db2c3d
Showing 1 changed file with 27 additions and 20 deletions.
47 changes: 27 additions & 20 deletions xblock/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
This code is in the Runtime layer.
"""
from __future__ import annotations

import functools
import itertools
import logging
import pkg_resources
import typing as t

from pkg_resources import iter_entry_points, EntryPoint

from xblock.internal import class_lazy

log = logging.getLogger(__name__)

PLUGIN_CACHE = {}
PLUGIN_CACHE: dict[tuple[str, str], type[Plugin]] = {}


class PluginMissingError(Exception):
Expand All @@ -21,35 +25,33 @@ class PluginMissingError(Exception):

class AmbiguousPluginError(Exception):
"""Raised when a class name produces more than one entry_point."""
def __init__(self, all_entry_points):
def __init__(self, all_entry_points: list[EntryPoint]):
classes = (entpt.load() for entpt in all_entry_points)
desc = ", ".join("{0.__module__}.{0.__name__}".format(cls) for cls in classes)
msg = f"Ambiguous entry points for {all_entry_points[0].name}: {desc}"
super().__init__(msg)


def default_select(identifier, all_entry_points): # pylint: disable=inconsistent-return-statements
def default_select(identifier: str, all_entry_points: list[EntryPoint]) -> EntryPoint:
"""
Raise an exception when we have ambiguous entry points.
Raise an exception when we have no entry points or ambiguous entry points.
"""

if len(all_entry_points) == 0:
if not all_entry_points:
raise PluginMissingError(identifier)
if len(all_entry_points) == 1:
return all_entry_points[0]
elif len(all_entry_points) > 1:
raise AmbiguousPluginError(all_entry_points)
return all_entry_points[0]


class Plugin:
"""Base class for a system that uses entry_points to load plugins.
"""
Base class for a system that uses entry_points to load plugins.
Implementing classes are expected to have the following attributes:
`entry_point`: The name of the entry point to load plugins from.
"""
entry_point = None # Should be overwritten by children classes
entry_point: str # Should be overwritten by children classes

@class_lazy
def extra_entry_points(cls): # pylint: disable=no-self-argument
Expand All @@ -62,7 +64,7 @@ def extra_entry_points(cls): # pylint: disable=no-self-argument
return []

@classmethod
def _load_class_entry_point(cls, entry_point):
def _load_class_entry_point(cls, entry_point: EntryPoint) -> type[t.Self]:
"""
Load `entry_point`, and set the `entry_point.name` as the
attribute `plugin_name` on the loaded object
Expand All @@ -72,7 +74,12 @@ def _load_class_entry_point(cls, entry_point):
return class_

@classmethod
def load_class(cls, identifier, default=None, select=None):
def load_class(
cls,
identifier: str,
default: type[t.Self] | None = None,
select: t.Callable[[str, list[EntryPoint]], EntryPoint] | None = None,
) -> type[t.Self]:
"""Load a single class specified by identifier.
If `identifier` specifies more than a single class, and `select` is not None,
Expand Down Expand Up @@ -100,7 +107,7 @@ def select(identifier, all_entry_points):
if select is None:
select = default_select

all_entry_points = list(pkg_resources.iter_entry_points(cls.entry_point, name=identifier))
all_entry_points = list(iter_entry_points(cls.entry_point, name=identifier))
for extra_identifier, extra_entry_point in iter(cls.extra_entry_points):
if identifier == extra_identifier:
all_entry_points.append(extra_entry_point)
Expand All @@ -117,7 +124,7 @@ def select(identifier, all_entry_points):
return PLUGIN_CACHE[key]

@classmethod
def load_classes(cls, fail_silently=True):
def load_classes(cls, fail_silently: bool = True) -> t.Iterable[tuple[str, type[t.Self]]]:
"""Load all the classes for a plugin.
Produces a sequence containing the identifiers and their corresponding
Expand All @@ -133,7 +140,7 @@ def load_classes(cls, fail_silently=True):
contexts. Hence, the flag.
"""
all_classes = itertools.chain(
pkg_resources.iter_entry_points(cls.entry_point),
iter_entry_points(cls.entry_point),
(entry_point for identifier, entry_point in iter(cls.extra_entry_points)),
)
for class_ in all_classes:
Expand All @@ -146,15 +153,15 @@ def load_classes(cls, fail_silently=True):
raise

@classmethod
def register_temp_plugin(cls, class_, identifier=None, dist='xblock'):
"""Decorate a function to run with a temporary plugin available.
def register_temp_plugin(cls, class_: type, identifier: str | None = None, dist: str = 'xblock'):
"""
Decorate a function to run with a temporary plugin available.
Use it like this in tests::
@register_temp_plugin(MyXBlockClass):
def test_the_thing():
# Here I can load MyXBlockClass by name.
"""
from unittest.mock import Mock # pylint: disable=import-outside-toplevel

Expand Down

0 comments on commit 1db2c3d

Please sign in to comment.