Skip to content

Commit

Permalink
initial work to enable extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
spyoungtech committed Aug 22, 2023
1 parent 7615f93 commit dbcc090
Show file tree
Hide file tree
Showing 10 changed files with 342 additions and 13 deletions.
36 changes: 35 additions & 1 deletion ahk/_async/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
import time
import warnings
from functools import partial
from typing import Any
from typing import Awaitable
from typing import Callable
Expand Down Expand Up @@ -34,6 +35,7 @@
else:
from typing import TypeAlias

from ..extensions import Extension, _extension_method_registry, _ExtensionMethodRegistry
from ..keys import Key
from .transport import AsyncDaemonProcessTransport
from .transport import AsyncFutureResult
Expand Down Expand Up @@ -135,13 +137,45 @@ def __init__(
TransportClass: Optional[Type[AsyncTransport]] = None,
directives: Optional[list[Directive | Type[Directive]]] = None,
executable_path: str = '',
extensions: list[Extension] | None | Literal['auto'] = None,
):
self._extension_registry: _ExtensionMethodRegistry
self._extensions: list[Extension]
if extensions == 'auto':
is_async = False
is_async = True # unasync: remove
if is_async:
extensions = [entry.extension for name, entry in _extension_method_registry.async_methods.items()]
else:
extensions = [entry.extension for name, entry in _extension_method_registry.sync_methods.items()]
self._extension_registry = _extension_method_registry
self._extensions = extensions
else:
self._extensions = extensions or []
self._extension_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})
for ext in self._extensions:
self._extension_registry.merge(ext._extension_method_registry)

if TransportClass is None:
TransportClass = AsyncDaemonProcessTransport
assert TransportClass is not None
transport = TransportClass(executable_path=executable_path, directives=directives)
transport = TransportClass(executable_path=executable_path, directives=directives, extensions=extensions)
self._transport: AsyncTransport = transport

def __getattr__(self, name: str) -> Callable[..., Any]:
is_async = False
is_async = True # unasync: remove
if is_async:
if name in self._extension_registry.async_methods:
method = self._extension_registry.async_methods[name].method
return partial(method, self)
else:
if name in self._extension_registry.sync_methods:
method = self._extension_registry.sync_methods[name].method
return partial(method, self)

raise AttributeError(f'{self.__class__.__name__!r} object has no attribute {name!r}')

def add_hotkey(
self, keyname: str, callback: Callable[[], Any], ex_handler: Optional[Callable[[str, Exception], Any]] = None
) -> None:
Expand Down
7 changes: 6 additions & 1 deletion ahk/_async/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import jinja2

from ahk.extensions import Extension
from ahk._hotkey import ThreadedHotkeyTransport, Hotkey, Hotstring
from ahk.message import RequestMessage
from ahk.message import ResponseMessage
Expand Down Expand Up @@ -656,7 +657,9 @@ def __init__(
directives: Optional[list[Directive | Type[Directive]]] = None,
jinja_loader: Optional[jinja2.BaseLoader] = None,
template: Optional[jinja2.Template] = None,
extensions: list[Extension] | None = None,
):
self._extensions = extensions or []
self._proc: Optional[AsyncAHKProcess]
self._proc = None
self._temp_script: Optional[str] = None
Expand Down Expand Up @@ -711,7 +714,9 @@ def _render_script(self, template: Optional[jinja2.Template] = None, **kwargs: A
template = self._template
kwargs['daemon'] = self.__template
message_types = {str(tom, 'utf-8'): c.__name__.upper() for tom, c in _message_registry.items()}
return template.render(directives=self._directives, message_types=message_types, **kwargs)
return template.render(
directives=self._directives, message_types=message_types, extensions=self._extensions, **kwargs
)

@property
def lock(self) -> Any:
Expand Down
14 changes: 14 additions & 0 deletions ahk/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@
#NoEnv
#Persistent
#SingleInstance Off
{% block extension_directives %}
; BEGIN extension includes
{% for ext in extensions %}
{% for inc in ext.includes %}
{{ inc }}
{% endfor %}
{% endfor %}
; END extension includes
{% endblock extension_directives %}
; BEGIN user-defined directives
{% block user_directives %}
{% for directive in directives %}
Expand Down Expand Up @@ -2833,7 +2842,12 @@
return decoded_commands
}
; BEGIN extension scripts
{% for ext in extensions %}
{{ ext.script_text }}
{% endfor %}
; END extension scripts
{% block before_autoexecute %}
{% endblock before_autoexecute %}
Expand Down
35 changes: 34 additions & 1 deletion ahk/_sync/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
import time
import warnings
from functools import partial
from typing import Any
from typing import Awaitable
from typing import Callable
Expand Down Expand Up @@ -34,6 +35,7 @@
else:
from typing import TypeAlias

from ..extensions import Extension, _extension_method_registry, _ExtensionMethodRegistry
from ..keys import Key
from .transport import DaemonProcessTransport
from .transport import FutureResult
Expand Down Expand Up @@ -131,13 +133,44 @@ def __init__(
TransportClass: Optional[Type[Transport]] = None,
directives: Optional[list[Directive | Type[Directive]]] = None,
executable_path: str = '',
extensions: list[Extension] | None | Literal['auto'] = None
):
self._extension_registry: _ExtensionMethodRegistry
self._extensions: list[Extension]
if extensions == 'auto':
is_async = False
if is_async:
extensions = [entry.extension for name, entry in _extension_method_registry.async_methods.items()]
else:
extensions = [entry.extension for name, entry in _extension_method_registry.sync_methods.items()]
self._extension_registry = _extension_method_registry
self._extensions = extensions
else:
self._extensions = extensions or []
self._extension_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})
for ext in self._extensions:
self._extension_registry.merge(ext._extension_method_registry)


if TransportClass is None:
TransportClass = DaemonProcessTransport
assert TransportClass is not None
transport = TransportClass(executable_path=executable_path, directives=directives)
transport = TransportClass(executable_path=executable_path, directives=directives, extensions=extensions)
self._transport: Transport = transport

def __getattr__(self, name: str) -> Callable[..., Any]:
is_async = False
if is_async:
if name in self._extension_registry.async_methods:
method = self._extension_registry.async_methods[name].method
return partial(method, self)
else:
if name in self._extension_registry.sync_methods:
method = self._extension_registry.sync_methods[name].method
return partial(method, self)

raise AttributeError(f'{self.__class__.__name__!r} object has no attribute {name!r}')

def add_hotkey(
self, keyname: str, callback: Callable[[], Any], ex_handler: Optional[Callable[[str, Exception], Any]] = None
) -> None:
Expand Down
5 changes: 4 additions & 1 deletion ahk/_sync/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import jinja2

from ahk.extensions import Extension
from ahk._hotkey import ThreadedHotkeyTransport, Hotkey, Hotstring
from ahk.message import RequestMessage
from ahk.message import ResponseMessage
Expand Down Expand Up @@ -630,7 +631,9 @@ def __init__(
directives: Optional[list[Directive | Type[Directive]]] = None,
jinja_loader: Optional[jinja2.BaseLoader] = None,
template: Optional[jinja2.Template] = None,
extensions: list[Extension] | None = None
):
self._extensions = extensions or []
self._proc: Optional[SyncAHKProcess]
self._proc = None
self._temp_script: Optional[str] = None
Expand Down Expand Up @@ -684,7 +687,7 @@ def _render_script(self, template: Optional[jinja2.Template] = None, **kwargs: A
template = self._template
kwargs['daemon'] = self.__template
message_types = {str(tom, 'utf-8'): c.__name__.upper() for tom, c in _message_registry.items()}
return template.render(directives=self._directives, message_types=message_types, **kwargs)
return template.render(directives=self._directives, message_types=message_types, extensions=self._extensions, **kwargs)

@property
def lock(self) -> Any:
Expand Down
12 changes: 3 additions & 9 deletions ahk/_sync/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,11 @@ def __hash__(self) -> int:
return hash(self._ahk_id)

def close(self) -> None:
self._engine.win_close(
title=f'ahk_id {self._ahk_id}', detect_hidden_windows=True, title_match_mode=(1, 'Fast')
)
self._engine.win_close(title=f'ahk_id {self._ahk_id}', detect_hidden_windows=True, title_match_mode=(1, 'Fast'))
return None

def kill(self) -> None:
self._engine.win_kill(
title=f'ahk_id {self._ahk_id}', detect_hidden_windows=True, title_match_mode=(1, 'Fast')
)
self._engine.win_kill(title=f'ahk_id {self._ahk_id}', detect_hidden_windows=True, title_match_mode=(1, 'Fast'))

def exists(self) -> bool:
return self._engine.win_exists(
Expand Down Expand Up @@ -591,9 +587,7 @@ def set_transparent(
blocking=blocking,
)

def set_trans_color(
self, color: Union[int, str], *, blocking: bool = True
) -> Union[None, FutureResult[None]]:
def set_trans_color(self, color: Union[int, str], *, blocking: bool = True) -> Union[None, FutureResult[None]]:
return self._engine.win_set_trans_color(
color=color,
title=f'ahk_id {self._ahk_id}',
Expand Down
99 changes: 99 additions & 0 deletions ahk/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

import asyncio
import warnings
from dataclasses import dataclass
from typing import Any
from typing import Callable
from typing import ParamSpec
from typing import TypeVar

from .directives import Include


@dataclass
class _ExtensionEntry:
extension: Extension
method: Callable[..., Any]


@dataclass
class _ExtensionMethodRegistry:
sync_methods: dict[str, _ExtensionEntry]
async_methods: dict[str, _ExtensionEntry]

def register(self, ext: Extension, f: Callable[P, T]) -> Callable[P, T]:
if asyncio.iscoroutinefunction(f):
if f.__name__ in self.async_methods:
warnings.warn(
f'Method of name {f.__name__!r} has already been registered. '
f'Previously registered method {self.async_methods[f.__name__].method!r} '
f'will be overridden by {f!r}'
)
self.async_methods[f.__name__] = _ExtensionEntry(extension=ext, method=f)
else:
if f.__name__ in self.sync_methods:
warnings.warn(
f'Method of name {f.__name__!r} has already been registered. '
f'Previously registered method {self.sync_methods[f.__name__].method!r} '
f'will be overridden by {f!r}'
)
self.sync_methods[f.__name__] = _ExtensionEntry(extension=ext, method=f)
return f

def merge(self, other: _ExtensionMethodRegistry) -> None:
for fname, entry in other.async_methods.items():
async_method = entry.method
if async_method.__name__ in self.async_methods:
warnings.warn(
f'Method of name {async_method.__name__!r} has already been registered. '
f'Previously registered method {self.async_methods[async_method.__name__].method!r} '
f'will be overridden by {async_method!r}'
)
self.async_methods[async_method.__name__] = entry
for fname, entry in other.sync_methods.items():
method = entry.method
if method.__name__ in self.sync_methods:
warnings.warn(
f'Method of name {method.__name__!r} has already been registered. '
f'Previously registered method {self.sync_methods[method.__name__].method!r} '
f'will be overridden by {method!r}'
)
self.sync_methods[method.__name__] = entry


_extension_method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})


T = TypeVar('T')
P = ParamSpec('P')


class Extension:
def __init__(
self,
includes: list[str] | None = None,
script_text: str | None = None,
# template: str | Template | None = None
):
self._text: str = script_text or ''
# self._template: str | Template | None = template
self._includes: list[str] = includes or []
self._extension_method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})

@property
def script_text(self) -> str:
return self._text

@script_text.setter
def script_text(self, new_script: str) -> None:
self._text = new_script

@property
def includes(self) -> list[Include]:
return [Include(inc) for inc in self._includes]

def register(self, f: Callable[P, T]) -> Callable[P, T]:
self._extension_method_registry.register(self, f)
_extension_method_registry.register(self, f)
return f
14 changes: 14 additions & 0 deletions ahk/templates/daemon.ahk
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
#NoEnv
#Persistent
#SingleInstance Off
{% block extension_directives %}
; BEGIN extension includes
{% for ext in extensions %}
{% for inc in ext.includes %}
{{ inc }}

{% endfor %}
{% endfor %}
; END extension includes
{% endblock extension_directives %}
; BEGIN user-defined directives
{% block user_directives %}
{% for directive in directives %}
Expand Down Expand Up @@ -2830,7 +2839,12 @@ CommandArrayFromQuery(ByRef text) {
return decoded_commands
}

; BEGIN extension scripts
{% for ext in extensions %}
{{ ext.script_text }}

{% endfor %}
; END extension scripts
{% block before_autoexecute %}
{% endblock before_autoexecute %}

Expand Down
Loading

0 comments on commit dbcc090

Please sign in to comment.