Skip to content

Commit

Permalink
extension registry cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
spyoungtech committed Aug 30, 2023
1 parent 92c0c77 commit 09bc676
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 104 deletions.
41 changes: 19 additions & 22 deletions ahk/_async/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
else:
from typing import TypeAlias

from ..extensions import Extension, _extension_method_registry, _ExtensionMethodRegistry
from ..extensions import (
Extension,
_ExtensionMethodRegistry,
_extension_registry,
_resolve_extensions,
)
from ..keys import Key
from .transport import AsyncDaemonProcessTransport
from .transport import AsyncFutureResult
Expand Down Expand Up @@ -142,38 +147,28 @@ def __init__(
self._extension_registry: _ExtensionMethodRegistry
self._extensions: list[Extension]
if extensions == 'auto':
is_async = False
is_async = True # unasync: remove
if is_async:
methods = _extension_method_registry.async_methods
else:
methods = _extension_method_registry.sync_methods
extensions = list(set(entry.extension for name, entry in methods.items()))

self._extension_registry = _extension_method_registry
self._extensions = extensions
self._extensions = list(_extension_registry)
else:
self._extensions = extensions or []
self._extension_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})
for ext in self._extensions:
self._extension_registry.merge(ext._extension_method_registry)

self._extensions = _resolve_extensions(extensions) if extensions else []
self._method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})
for ext in self._extensions:
self._method_registry.merge(ext._extension_method_registry)
if TransportClass is None:
TransportClass = AsyncDaemonProcessTransport
assert TransportClass is not None
transport = TransportClass(executable_path=executable_path, directives=directives, extensions=extensions)
transport = TransportClass(executable_path=executable_path, directives=directives, extensions=self._extensions)
self._transport: AsyncTransport = transport

def __getattr__(self, name: str) -> Callable[..., Any]:
is_async = False
is_async = True # unasync: remove
if is_async:
if name in self._extension_registry.async_methods:
method = self._extension_registry.async_methods[name].method
if name in self._method_registry.async_methods:
method = self._method_registry.async_methods[name]
return partial(method, self)
else:
if name in self._extension_registry.sync_methods:
method = self._extension_registry.sync_methods[name].method
if name in self._method_registry.sync_methods:
method = self._method_registry.sync_methods[name]
return partial(method, self)

raise AttributeError(f'{self.__class__.__name__!r} object has no attribute {name!r}')
Expand Down Expand Up @@ -204,10 +199,12 @@ def add_hotkey(
warnings.warn(warning.message, warning.category, stacklevel=2)
return None

async def function_call(self, function_name: str, args: list[str], blocking: bool = True) -> Any:
async def function_call(self, function_name: str, args: list[str] | None = None, blocking: bool = True) -> Any:
"""
Call an AHK function defined in the daemon script. This method is intended for use by extension authors.
"""
if args is None:
args = []
return await self._transport.function_call(function_name, args, blocking=blocking, engine=self) # type: ignore[call-overload]

def add_hotstring(
Expand Down
6 changes: 5 additions & 1 deletion ahk/_async/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

import jinja2

from ahk.extensions import Extension
from ahk.extensions import Extension, _resolve_includes
from ahk._hotkey import ThreadedHotkeyTransport, Hotkey, Hotstring
from ahk.message import RequestMessage
from ahk.message import ResponseMessage
Expand Down Expand Up @@ -689,6 +689,10 @@ def __init__(
if template is None:
template = self.__template
self._template: jinja2.Template = template
directives = directives or []
if extensions:
includes = _resolve_includes(extensions)
directives = includes + directives
super().__init__(executable_path=executable_path, directives=directives)

@property
Expand Down
10 changes: 0 additions & 10 deletions ahk/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,6 @@
#NoEnv
#Persistent
#SingleInstance Off
{% block extension_directives %}
; BEGIN extension includes
{% for ext in extensions %}
{% for inc in ext.includes %}
{{ inc }}
{% endfor %}
{% endfor %}
; END extension includes
{% endblock extension_directives %}
; BEGIN user-defined directives
{% block user_directives %}
{% for directive in directives %}
Expand Down
40 changes: 19 additions & 21 deletions ahk/_sync/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
else:
from typing import TypeAlias

from ..extensions import Extension, _extension_method_registry, _ExtensionMethodRegistry
from ..extensions import (
Extension,
_ExtensionMethodRegistry,
_extension_registry,
_resolve_extensions,
)
from ..keys import Key
from .transport import DaemonProcessTransport
from .transport import FutureResult
Expand Down Expand Up @@ -138,36 +143,27 @@ def __init__(
self._extension_registry: _ExtensionMethodRegistry
self._extensions: list[Extension]
if extensions == 'auto':
is_async = False
if is_async:
methods = _extension_method_registry.async_methods
else:
methods = _extension_method_registry.sync_methods
extensions = list(set(entry.extension for name, entry in methods.items()))

self._extension_registry = _extension_method_registry
self._extensions = extensions
self._extensions = list(_extension_registry)
else:
self._extensions = extensions or []
self._extension_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})
for ext in self._extensions:
self._extension_registry.merge(ext._extension_method_registry)

self._extensions = _resolve_extensions(extensions) if extensions else []
self._method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})
for ext in self._extensions:
self._method_registry.merge(ext._extension_method_registry)
if TransportClass is None:
TransportClass = DaemonProcessTransport
assert TransportClass is not None
transport = TransportClass(executable_path=executable_path, directives=directives, extensions=extensions)
transport = TransportClass(executable_path=executable_path, directives=directives, extensions=self._extensions)
self._transport: Transport = transport

def __getattr__(self, name: str) -> Callable[..., Any]:
is_async = False
if is_async:
if name in self._extension_registry.async_methods:
method = self._extension_registry.async_methods[name].method
if name in self._method_registry.async_methods:
method = self._method_registry.async_methods[name]
return partial(method, self)
else:
if name in self._extension_registry.sync_methods:
method = self._extension_registry.sync_methods[name].method
if name in self._method_registry.sync_methods:
method = self._method_registry.sync_methods[name]
return partial(method, self)

raise AttributeError(f'{self.__class__.__name__!r} object has no attribute {name!r}')
Expand Down Expand Up @@ -198,10 +194,12 @@ def add_hotkey(
warnings.warn(warning.message, warning.category, stacklevel=2)
return None

def function_call(self, function_name: str, args: list[str], blocking: bool = True) -> Any:
def function_call(self, function_name: str, args: list[str] | None = None, blocking: bool = True) -> Any:
"""
Call an AHK function defined in the daemon script. This method is intended for use by extension authors.
"""
if args is None:
args = []
return self._transport.function_call(function_name, args, blocking=blocking, engine=self) # type: ignore[call-overload]

def add_hotstring(
Expand Down
6 changes: 5 additions & 1 deletion ahk/_sync/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

import jinja2

from ahk.extensions import Extension
from ahk.extensions import Extension, _resolve_includes
from ahk._hotkey import ThreadedHotkeyTransport, Hotkey, Hotstring
from ahk.message import RequestMessage
from ahk.message import ResponseMessage
Expand Down Expand Up @@ -662,6 +662,10 @@ def __init__(
if template is None:
template = self.__template
self._template: jinja2.Template = template
directives = directives or []
if extensions:
includes = _resolve_includes(extensions)
directives = includes + directives
super().__init__(executable_path=executable_path, directives=directives)

@property
Expand Down
108 changes: 75 additions & 33 deletions ahk/extensions.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from __future__ import annotations

import asyncio
import itertools
import sys
import typing
import warnings
from collections import deque
from dataclasses import dataclass
from typing import Any
from typing import Callable
from typing import TypeVar

if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
from typing_extensions import Concatenate
else:
from typing import ParamSpec
from typing import Concatenate

from .directives import Include

Expand All @@ -26,67 +31,64 @@ class _ExtensionEntry:
P = ParamSpec('P')


if typing.TYPE_CHECKING:
from ahk import AHK, AsyncAHK

TAHK = TypeVar('TAHK', bound=typing.Union[AHK, AsyncAHK])


@dataclass
class _ExtensionMethodRegistry:
sync_methods: dict[str, _ExtensionEntry]
async_methods: dict[str, _ExtensionEntry]
sync_methods: dict[str, Callable[..., Any]]
async_methods: dict[str, Callable[..., Any]]

def register(self, ext: Extension, f: Callable[P, T]) -> Callable[P, T]:
def register(self, f: Callable[Concatenate[TAHK, P], T]) -> Callable[Concatenate[TAHK, P], T]:
if asyncio.iscoroutinefunction(f):
if f.__name__ in self.async_methods:
warnings.warn(
f'Method of name {f.__name__!r} has already been registered. '
f'Previously registered method {self.async_methods[f.__name__].method!r} '
f'Previously registered method {self.async_methods[f.__name__]!r} '
f'will be overridden by {f!r}',
stacklevel=2,
)
self.async_methods[f.__name__] = _ExtensionEntry(extension=ext, method=f)
self.async_methods[f.__name__] = f
else:
if f.__name__ in self.sync_methods:
warnings.warn(
f'Method of name {f.__name__!r} has already been registered. '
f'Previously registered method {self.sync_methods[f.__name__].method!r} '
f'Previously registered method {self.sync_methods[f.__name__]!r} '
f'will be overridden by {f!r}',
stacklevel=2,
)
self.sync_methods[f.__name__] = _ExtensionEntry(extension=ext, method=f)
self.sync_methods[f.__name__] = f
return f

def merge(self, other: _ExtensionMethodRegistry) -> None:
for fname, entry in other.async_methods.items():
async_method = entry.method
if async_method.__name__ in self.async_methods:
warnings.warn(
f'Method of name {async_method.__name__!r} has already been registered. '
f'Previously registered method {self.async_methods[async_method.__name__].method!r} '
f'will be overridden by {async_method!r}'
)
self.async_methods[async_method.__name__] = entry
for fname, entry in other.sync_methods.items():
method = entry.method
if method.__name__ in self.sync_methods:
warnings.warn(
f'Method of name {method.__name__!r} has already been registered. '
f'Previously registered method {self.sync_methods[method.__name__].method!r} '
f'will be overridden by {method!r}'
)
self.sync_methods[method.__name__] = entry
for name, method in other.methods:
self.register(method)

@property
def methods(self) -> list[tuple[str, Callable[..., Any]]]:
return list(itertools.chain(self.async_methods.items(), self.sync_methods.items()))


_extension_method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})
_extension_registry: dict[Extension, _ExtensionMethodRegistry] = {}


class Extension:
def __init__(
self,
includes: list[str] | None = None,
script_text: str | None = None,
# template: str | Template | None = None
includes: list[str] | None = None,
dependencies: list[Extension] | None = None,
):
self._text: str = script_text or ''
# self._template: str | Template | None = template
self._includes: list[str] = includes or []
self._extension_method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})
self.dependencies: list[Extension] = dependencies or []
self._extension_method_registry: _ExtensionMethodRegistry = _ExtensionMethodRegistry(
sync_methods={}, async_methods={}
)
_extension_registry[self] = self._extension_method_registry

@property
def script_text(self) -> str:
Expand All @@ -100,7 +102,47 @@ def script_text(self, new_script: str) -> None:
def includes(self) -> list[Include]:
return [Include(inc) for inc in self._includes]

def register(self, f: Callable[P, T]) -> Callable[P, T]:
self._extension_method_registry.register(self, f)
_extension_method_registry.register(self, f)
def register(self, f: Callable[Concatenate[TAHK, P], T]) -> Callable[Concatenate[TAHK, P], T]:
self._extension_method_registry.register(f)
return f

def __hash__(self) -> int:
return hash((self._text, tuple(self.includes), tuple(self.dependencies)))

def __eq__(self, other: Any) -> bool:
if isinstance(other, Extension):
return hash(self) == hash(other)
return NotImplemented


def _resolve_extension(extension: Extension, seen: set[Extension]) -> list[Extension]:
ret: deque[Extension] = deque()
todo = [extension]
while todo:
ext = todo.pop()
if ext in seen:
continue
ret.appendleft(ext)
seen.add(ext)
todo.extend(ext.dependencies)
return list(ret)


def _resolve_extensions(extensions: list[Extension]) -> list[Extension]:
seen: set[Extension] = set()
ret: list[Extension] = []
for ext in extensions:
ret.extend(_resolve_extension(ext, seen=seen))
return ret


def _resolve_includes(extensions: list[Extension]) -> list[Include]:
extensions = _resolve_extensions(extensions)
ret = []
seen: set[Include] = set()
for ext in extensions:
for include in ext.includes:
if include in seen:
continue
ret.append(include)
return ret
10 changes: 0 additions & 10 deletions ahk/templates/daemon.ahk
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,6 @@
#NoEnv
#Persistent
#SingleInstance Off
{% block extension_directives %}
; BEGIN extension includes
{% for ext in extensions %}
{% for inc in ext.includes %}
{{ inc }}

{% endfor %}
{% endfor %}
; END extension includes
{% endblock extension_directives %}
; BEGIN user-defined directives
{% block user_directives %}
{% for directive in directives %}
Expand Down
Loading

0 comments on commit 09bc676

Please sign in to comment.