diff --git a/colcon_core/command.py b/colcon_core/command.py index bad98e22..10668edc 100644 --- a/colcon_core/command.py +++ b/colcon_core/command.py @@ -90,7 +90,10 @@ def register_command_exit_handler(handler): _command_exit_handlers.append(handler) -def main(*, command_name='colcon', argv=None): +def main( + *, command_name='colcon', argv=None, verb_group_name=None, + environment_variable_group_name=None, +): """ Execute the main logic of the command. @@ -113,7 +116,10 @@ def main(*, command_name='colcon', argv=None): :returns: The return code """ try: - return _main(command_name=command_name, argv=argv) + return _main( + command_name=command_name, argv=argv, + verb_group_name=verb_group_name, + environment_variable_group_name=environment_variable_group_name) except KeyboardInterrupt: return signal.SIGINT finally: @@ -123,7 +129,9 @@ def main(*, command_name='colcon', argv=None): handler() -def _main(*, command_name, argv): +def _main( + *, command_name, argv, verb_group_name, environment_variable_group_name, +): # default log level, for searchability: COLCON_LOG_LEVEL colcon_logger.setLevel(logging.WARNING) set_logger_level_from_env( @@ -137,9 +145,9 @@ def _main(*, command_name, argv): path=(Path('~') / f'.{command_name}').expanduser(), env_var=f'{command_name}_HOME'.upper()) - parser = create_parser('colcon_core.environment_variable') + parser = create_parser(environment_variable_group_name) - verb_extensions = get_verb_extensions() + verb_extensions = get_verb_extensions(group_name=verb_group_name) # add subparsers for all verb extensions but without arguments for now subparser = create_subparser( @@ -203,7 +211,7 @@ def _main(*, command_name, argv): return verb_main(context, colcon_logger) -def create_parser(environment_variables_group_name): +def create_parser(environment_variables_group_name=None): """ Create the argument parser. @@ -283,7 +291,7 @@ def _split_lines(self, text, width): return lines -def get_environment_variables_epilog(group_name): +def get_environment_variables_epilog(group_name=None): """ Get a message enumerating the registered environment variables. @@ -292,6 +300,8 @@ def get_environment_variables_epilog(group_name): :returns: The message for the argument parser epilog :rtype: str """ + if group_name is None: + group_name = 'colcon_core.environment_variable' # list environment variables with descriptions entry_points = load_extension_points(group_name) if not entry_points: diff --git a/colcon_core/extension_point.py b/colcon_core/extension_point.py index 4bea3fc6..e04ee0fd 100644 --- a/colcon_core/extension_point.py +++ b/colcon_core/extension_point.py @@ -22,11 +22,14 @@ from colcon_core.environment_variable import EnvironmentVariable from colcon_core.logging import colcon_logger -"""Environment variable to block extensions""" -EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE = EnvironmentVariable( +_EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE = EnvironmentVariable( 'COLCON_EXTENSION_BLOCKLIST', 'Block extensions which should not be used') +"""Environment variable to block extensions""" +EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE = \ + _EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE + logger = colcon_logger.getChild(__name__) """ @@ -205,3 +208,16 @@ def load_extension_point(name, value, group): 'The entry point name is listed in the environment variable ' f"'{EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE.name}'") return EntryPoint(name, value, group).load() + + +def override_blocklist_variable(variable): + """ + Override the blocklist environment variable. + + :param EnvironmentVariable variable: The new blocklist environment + variable, or None to reset to default. + """ + if variable is None: + variable = _EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE + global EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE + EXTENSION_BLOCKLIST_ENVIRONMENT_VARIABLE = variable diff --git a/test/test_extension_point.py b/test/test_extension_point.py index 63f89edb..f0fa8043 100644 --- a/test/test_extension_point.py +++ b/test/test_extension_point.py @@ -12,6 +12,7 @@ # TODO: Drop this with Python 3.7 support from importlib_metadata import Distribution +from colcon_core.environment_variable import EnvironmentVariable from colcon_core.extension_point import clear_entry_point_cache from colcon_core.extension_point import EntryPoint from colcon_core.extension_point import EXTENSION_POINT_GROUP_NAME @@ -19,6 +20,7 @@ from colcon_core.extension_point import get_extension_points from colcon_core.extension_point import load_extension_point from colcon_core.extension_point import load_extension_points +from colcon_core.extension_point import override_blocklist_variable import pytest from .environment_context import EnvironmentContext @@ -139,6 +141,45 @@ def test_extension_point_blocklist(): assert load.call_count == 0 +def test_extension_point_blocklist_override(): + with patch.object(EntryPoint, 'load', return_value=None) as load: + clear_entry_point_cache() + + my_extension_blocklist = EnvironmentVariable( + 'MY_EXTENSION_BLOCKLIST', 'Foo bar baz') + override_blocklist_variable(my_extension_blocklist) + + try: + # entry point in default blocklist variable can be loaded + load.reset_mock() + with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST='group1'): + clear_entry_point_cache() + load_extension_point('extA', 'eA', 'group1') + assert load.call_count == 1 + + # entry point in custom blocklist variable can't be loaded + load.reset_mock() + with EnvironmentContext(MY_EXTENSION_BLOCKLIST='group1'): + clear_entry_point_cache() + with pytest.raises(RuntimeError) as e: + load_extension_point('extA', 'eA', 'group1') + assert 'The entry point group name is listed in the ' \ + 'environment variable' in str(e.value) + assert load.call_count == 0 + finally: + override_blocklist_variable(None) + + # entry point in default blocklist variable can no longer be loaded + load.reset_mock() + with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST='group1'): + clear_entry_point_cache() + with pytest.raises(RuntimeError) as e: + load_extension_point('extA', 'eA', 'group1') + assert 'The entry point group name is listed in the ' \ + 'environment variable' in str(e.value) + assert load.call_count == 0 + + def test_redefined_extension_point(): def _duped_distributions(): yield from _distributions()