From 09bc6767f6db4768bb11b8d2cb0b4b0be1032dcf Mon Sep 17 00:00:00 2001 From: Spencer Phillip Young Date: Wed, 30 Aug 2023 14:56:58 -0700 Subject: [PATCH] extension registry cleanup --- ahk/_async/engine.py | 41 ++++++------ ahk/_async/transport.py | 6 +- ahk/_constants.py | 10 --- ahk/_sync/engine.py | 40 ++++++------ ahk/_sync/transport.py | 6 +- ahk/extensions.py | 108 ++++++++++++++++++++++---------- ahk/templates/daemon.ahk | 10 --- tests/_async/test_extensions.py | 6 +- tests/_sync/test_extensions.py | 6 +- 9 files changed, 129 insertions(+), 104 deletions(-) diff --git a/ahk/_async/engine.py b/ahk/_async/engine.py index 7d2e51c..d0b1bf0 100644 --- a/ahk/_async/engine.py +++ b/ahk/_async/engine.py @@ -35,7 +35,12 @@ else: from typing import TypeAlias -from ..extensions import Extension, _extension_method_registry, _ExtensionMethodRegistry +from ..extensions import ( + Extension, + _ExtensionMethodRegistry, + _extension_registry, + _resolve_extensions, +) from ..keys import Key from .transport import AsyncDaemonProcessTransport from .transport import AsyncFutureResult @@ -142,38 +147,28 @@ def __init__( self._extension_registry: _ExtensionMethodRegistry self._extensions: list[Extension] if extensions == 'auto': - is_async = False - is_async = True # unasync: remove - if is_async: - methods = _extension_method_registry.async_methods - else: - methods = _extension_method_registry.sync_methods - extensions = list(set(entry.extension for name, entry in methods.items())) - - self._extension_registry = _extension_method_registry - self._extensions = extensions + self._extensions = list(_extension_registry) else: - self._extensions = extensions or [] - self._extension_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={}) - for ext in self._extensions: - self._extension_registry.merge(ext._extension_method_registry) - + self._extensions = _resolve_extensions(extensions) if extensions else [] + self._method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={}) + for ext in self._extensions: + self._method_registry.merge(ext._extension_method_registry) if TransportClass is None: TransportClass = AsyncDaemonProcessTransport assert TransportClass is not None - transport = TransportClass(executable_path=executable_path, directives=directives, extensions=extensions) + transport = TransportClass(executable_path=executable_path, directives=directives, extensions=self._extensions) self._transport: AsyncTransport = transport def __getattr__(self, name: str) -> Callable[..., Any]: is_async = False is_async = True # unasync: remove if is_async: - if name in self._extension_registry.async_methods: - method = self._extension_registry.async_methods[name].method + if name in self._method_registry.async_methods: + method = self._method_registry.async_methods[name] return partial(method, self) else: - if name in self._extension_registry.sync_methods: - method = self._extension_registry.sync_methods[name].method + if name in self._method_registry.sync_methods: + method = self._method_registry.sync_methods[name] return partial(method, self) raise AttributeError(f'{self.__class__.__name__!r} object has no attribute {name!r}') @@ -204,10 +199,12 @@ def add_hotkey( warnings.warn(warning.message, warning.category, stacklevel=2) return None - async def function_call(self, function_name: str, args: list[str], blocking: bool = True) -> Any: + async def function_call(self, function_name: str, args: list[str] | None = None, blocking: bool = True) -> Any: """ Call an AHK function defined in the daemon script. This method is intended for use by extension authors. """ + if args is None: + args = [] return await self._transport.function_call(function_name, args, blocking=blocking, engine=self) # type: ignore[call-overload] def add_hotstring( diff --git a/ahk/_async/transport.py b/ahk/_async/transport.py index 96277d3..c6b9842 100644 --- a/ahk/_async/transport.py +++ b/ahk/_async/transport.py @@ -38,7 +38,7 @@ import jinja2 -from ahk.extensions import Extension +from ahk.extensions import Extension, _resolve_includes from ahk._hotkey import ThreadedHotkeyTransport, Hotkey, Hotstring from ahk.message import RequestMessage from ahk.message import ResponseMessage @@ -689,6 +689,10 @@ def __init__( if template is None: template = self.__template self._template: jinja2.Template = template + directives = directives or [] + if extensions: + includes = _resolve_includes(extensions) + directives = includes + directives super().__init__(executable_path=executable_path, directives=directives) @property diff --git a/ahk/_constants.py b/ahk/_constants.py index baea312..f7485eb 100644 --- a/ahk/_constants.py +++ b/ahk/_constants.py @@ -6,16 +6,6 @@ #NoEnv #Persistent #SingleInstance Off -{% block extension_directives %} -; BEGIN extension includes -{% for ext in extensions %} -{% for inc in ext.includes %} -{{ inc }} - -{% endfor %} -{% endfor %} -; END extension includes -{% endblock extension_directives %} ; BEGIN user-defined directives {% block user_directives %} {% for directive in directives %} diff --git a/ahk/_sync/engine.py b/ahk/_sync/engine.py index 2d7cfd6..2240106 100644 --- a/ahk/_sync/engine.py +++ b/ahk/_sync/engine.py @@ -35,7 +35,12 @@ else: from typing import TypeAlias -from ..extensions import Extension, _extension_method_registry, _ExtensionMethodRegistry +from ..extensions import ( + Extension, + _ExtensionMethodRegistry, + _extension_registry, + _resolve_extensions, +) from ..keys import Key from .transport import DaemonProcessTransport from .transport import FutureResult @@ -138,36 +143,27 @@ def __init__( self._extension_registry: _ExtensionMethodRegistry self._extensions: list[Extension] if extensions == 'auto': - is_async = False - if is_async: - methods = _extension_method_registry.async_methods - else: - methods = _extension_method_registry.sync_methods - extensions = list(set(entry.extension for name, entry in methods.items())) - - self._extension_registry = _extension_method_registry - self._extensions = extensions + self._extensions = list(_extension_registry) else: - self._extensions = extensions or [] - self._extension_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={}) - for ext in self._extensions: - self._extension_registry.merge(ext._extension_method_registry) - + self._extensions = _resolve_extensions(extensions) if extensions else [] + self._method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={}) + for ext in self._extensions: + self._method_registry.merge(ext._extension_method_registry) if TransportClass is None: TransportClass = DaemonProcessTransport assert TransportClass is not None - transport = TransportClass(executable_path=executable_path, directives=directives, extensions=extensions) + transport = TransportClass(executable_path=executable_path, directives=directives, extensions=self._extensions) self._transport: Transport = transport def __getattr__(self, name: str) -> Callable[..., Any]: is_async = False if is_async: - if name in self._extension_registry.async_methods: - method = self._extension_registry.async_methods[name].method + if name in self._method_registry.async_methods: + method = self._method_registry.async_methods[name] return partial(method, self) else: - if name in self._extension_registry.sync_methods: - method = self._extension_registry.sync_methods[name].method + if name in self._method_registry.sync_methods: + method = self._method_registry.sync_methods[name] return partial(method, self) raise AttributeError(f'{self.__class__.__name__!r} object has no attribute {name!r}') @@ -198,10 +194,12 @@ def add_hotkey( warnings.warn(warning.message, warning.category, stacklevel=2) return None - def function_call(self, function_name: str, args: list[str], blocking: bool = True) -> Any: + def function_call(self, function_name: str, args: list[str] | None = None, blocking: bool = True) -> Any: """ Call an AHK function defined in the daemon script. This method is intended for use by extension authors. """ + if args is None: + args = [] return self._transport.function_call(function_name, args, blocking=blocking, engine=self) # type: ignore[call-overload] def add_hotstring( diff --git a/ahk/_sync/transport.py b/ahk/_sync/transport.py index e786a19..8408566 100644 --- a/ahk/_sync/transport.py +++ b/ahk/_sync/transport.py @@ -38,7 +38,7 @@ import jinja2 -from ahk.extensions import Extension +from ahk.extensions import Extension, _resolve_includes from ahk._hotkey import ThreadedHotkeyTransport, Hotkey, Hotstring from ahk.message import RequestMessage from ahk.message import ResponseMessage @@ -662,6 +662,10 @@ def __init__( if template is None: template = self.__template self._template: jinja2.Template = template + directives = directives or [] + if extensions: + includes = _resolve_includes(extensions) + directives = includes + directives super().__init__(executable_path=executable_path, directives=directives) @property diff --git a/ahk/extensions.py b/ahk/extensions.py index 00f8ddf..ee71c9d 100644 --- a/ahk/extensions.py +++ b/ahk/extensions.py @@ -1,8 +1,11 @@ from __future__ import annotations import asyncio +import itertools import sys +import typing import warnings +from collections import deque from dataclasses import dataclass from typing import Any from typing import Callable @@ -10,8 +13,10 @@ if sys.version_info < (3, 10): from typing_extensions import ParamSpec + from typing_extensions import Concatenate else: from typing import ParamSpec + from typing import Concatenate from .directives import Include @@ -26,67 +31,64 @@ class _ExtensionEntry: P = ParamSpec('P') +if typing.TYPE_CHECKING: + from ahk import AHK, AsyncAHK + + TAHK = TypeVar('TAHK', bound=typing.Union[AHK, AsyncAHK]) + + @dataclass class _ExtensionMethodRegistry: - sync_methods: dict[str, _ExtensionEntry] - async_methods: dict[str, _ExtensionEntry] + sync_methods: dict[str, Callable[..., Any]] + async_methods: dict[str, Callable[..., Any]] - def register(self, ext: Extension, f: Callable[P, T]) -> Callable[P, T]: + def register(self, f: Callable[Concatenate[TAHK, P], T]) -> Callable[Concatenate[TAHK, P], T]: if asyncio.iscoroutinefunction(f): if f.__name__ in self.async_methods: warnings.warn( f'Method of name {f.__name__!r} has already been registered. ' - f'Previously registered method {self.async_methods[f.__name__].method!r} ' + f'Previously registered method {self.async_methods[f.__name__]!r} ' f'will be overridden by {f!r}', stacklevel=2, ) - self.async_methods[f.__name__] = _ExtensionEntry(extension=ext, method=f) + self.async_methods[f.__name__] = f else: if f.__name__ in self.sync_methods: warnings.warn( f'Method of name {f.__name__!r} has already been registered. ' - f'Previously registered method {self.sync_methods[f.__name__].method!r} ' + f'Previously registered method {self.sync_methods[f.__name__]!r} ' f'will be overridden by {f!r}', stacklevel=2, ) - self.sync_methods[f.__name__] = _ExtensionEntry(extension=ext, method=f) + self.sync_methods[f.__name__] = f return f def merge(self, other: _ExtensionMethodRegistry) -> None: - for fname, entry in other.async_methods.items(): - async_method = entry.method - if async_method.__name__ in self.async_methods: - warnings.warn( - f'Method of name {async_method.__name__!r} has already been registered. ' - f'Previously registered method {self.async_methods[async_method.__name__].method!r} ' - f'will be overridden by {async_method!r}' - ) - self.async_methods[async_method.__name__] = entry - for fname, entry in other.sync_methods.items(): - method = entry.method - if method.__name__ in self.sync_methods: - warnings.warn( - f'Method of name {method.__name__!r} has already been registered. ' - f'Previously registered method {self.sync_methods[method.__name__].method!r} ' - f'will be overridden by {method!r}' - ) - self.sync_methods[method.__name__] = entry + for name, method in other.methods: + self.register(method) + + @property + def methods(self) -> list[tuple[str, Callable[..., Any]]]: + return list(itertools.chain(self.async_methods.items(), self.sync_methods.items())) -_extension_method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={}) +_extension_registry: dict[Extension, _ExtensionMethodRegistry] = {} class Extension: def __init__( self, - includes: list[str] | None = None, script_text: str | None = None, - # template: str | Template | None = None + includes: list[str] | None = None, + dependencies: list[Extension] | None = None, ): self._text: str = script_text or '' - # self._template: str | Template | None = template self._includes: list[str] = includes or [] - self._extension_method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={}) + self.dependencies: list[Extension] = dependencies or [] + self._extension_method_registry: _ExtensionMethodRegistry = _ExtensionMethodRegistry( + sync_methods={}, async_methods={} + ) + _extension_registry[self] = self._extension_method_registry @property def script_text(self) -> str: @@ -100,7 +102,47 @@ def script_text(self, new_script: str) -> None: def includes(self) -> list[Include]: return [Include(inc) for inc in self._includes] - def register(self, f: Callable[P, T]) -> Callable[P, T]: - self._extension_method_registry.register(self, f) - _extension_method_registry.register(self, f) + def register(self, f: Callable[Concatenate[TAHK, P], T]) -> Callable[Concatenate[TAHK, P], T]: + self._extension_method_registry.register(f) return f + + def __hash__(self) -> int: + return hash((self._text, tuple(self.includes), tuple(self.dependencies))) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Extension): + return hash(self) == hash(other) + return NotImplemented + + +def _resolve_extension(extension: Extension, seen: set[Extension]) -> list[Extension]: + ret: deque[Extension] = deque() + todo = [extension] + while todo: + ext = todo.pop() + if ext in seen: + continue + ret.appendleft(ext) + seen.add(ext) + todo.extend(ext.dependencies) + return list(ret) + + +def _resolve_extensions(extensions: list[Extension]) -> list[Extension]: + seen: set[Extension] = set() + ret: list[Extension] = [] + for ext in extensions: + ret.extend(_resolve_extension(ext, seen=seen)) + return ret + + +def _resolve_includes(extensions: list[Extension]) -> list[Include]: + extensions = _resolve_extensions(extensions) + ret = [] + seen: set[Include] = set() + for ext in extensions: + for include in ext.includes: + if include in seen: + continue + ret.append(include) + return ret diff --git a/ahk/templates/daemon.ahk b/ahk/templates/daemon.ahk index 32f136c..112a373 100644 --- a/ahk/templates/daemon.ahk +++ b/ahk/templates/daemon.ahk @@ -3,16 +3,6 @@ #NoEnv #Persistent #SingleInstance Off -{% block extension_directives %} -; BEGIN extension includes -{% for ext in extensions %} -{% for inc in ext.includes %} -{{ inc }} - -{% endfor %} -{% endfor %} -; END extension includes -{% endblock extension_directives %} ; BEGIN user-defined directives {% block user_directives %} {% for directive in directives %} diff --git a/tests/_async/test_extensions.py b/tests/_async/test_extensions.py index 1842a2e..e76ebfe 100644 --- a/tests/_async/test_extensions.py +++ b/tests/_async/test_extensions.py @@ -35,7 +35,7 @@ async def asyncTearDown(self) -> None: self.ahk._transport._proc.kill() time.sleep(0.2) - async def test_ext(self): + async def test_ext_explicit(self): res = await self.ahk.do_something('foo') assert res == 'testfoo' @@ -48,7 +48,7 @@ async def asyncTearDown(self) -> None: self.ahk._transport._proc.kill() time.sleep(0.2) - async def test_ext(self): + async def test_ext_auto(self): res = await self.ahk.do_something('foo') assert res == 'testfoo' @@ -62,5 +62,5 @@ async def asyncTearDown(self) -> None: self.ahk._transport._proc.kill() time.sleep(0.2) - async def test_ext(self): + async def test_ext_no_ext(self): assert not hasattr(self.ahk, 'do_something') diff --git a/tests/_sync/test_extensions.py b/tests/_sync/test_extensions.py index 8eadcb2..b7c8d14 100644 --- a/tests/_sync/test_extensions.py +++ b/tests/_sync/test_extensions.py @@ -34,7 +34,7 @@ def tearDown(self) -> None: self.ahk._transport._proc.kill() time.sleep(0.2) - def test_ext(self): + def test_ext_explicit(self): res = self.ahk.do_something('foo') assert res == 'testfoo' @@ -47,7 +47,7 @@ def tearDown(self) -> None: self.ahk._transport._proc.kill() time.sleep(0.2) - def test_ext(self): + def test_ext_auto(self): res = self.ahk.do_something('foo') assert res == 'testfoo' @@ -61,5 +61,5 @@ def tearDown(self) -> None: self.ahk._transport._proc.kill() time.sleep(0.2) - def test_ext(self): + def test_ext_no_ext(self): assert not hasattr(self.ahk, 'do_something')