Skip to content

Commit

Permalink
♻️ split userland module loading into utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
haliphax committed Apr 10, 2024
1 parent 87eb0be commit 4e8f5db
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 21 deletions.
31 changes: 31 additions & 0 deletions xthulu/scripting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Scripting utilities"""

# typing
from types import ModuleType

# stdlib
from importlib.machinery import ModuleSpec, PathFinder

# local
from .configuration import get_config


def load_userland_module(name: str) -> ModuleType | None:
"""Load module from userland scripts"""

pathfinder = PathFinder()
paths = get_config("ssh.userland.paths")
split: list[str] = name.split(".")
found: ModuleSpec | None = None
mod: ModuleType | None = None

for seg in split:
if mod is not None:
found = pathfinder.find_spec(seg, list(mod.__path__))
else:
found = pathfinder.find_spec(seg, paths)

if found is not None and found.loader is not None:
mod = found.loader.load_module(found.name)

return mod
23 changes: 2 additions & 21 deletions xthulu/ssh/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@

# type checking
from typing import Any, Callable, NoReturn, Optional
from types import ModuleType

# stdlib
from asyncio import Queue, sleep
from codecs import decode
from functools import partial, singledispatch
from importlib.abc import Loader
from importlib.machinery import PathFinder
import logging
import subprocess
import sys
Expand All @@ -20,19 +17,16 @@

# local
from ... import locks
from ...configuration import get_config
from ...events import EventQueue
from ...logger import log
from ...models import User
from ...scripting import load_userland_module
from ..console import XthuluConsole
from ..exceptions import Goto, ProcessClosing
from ..structs import Script
from .lock_manager import _LockManager
from .log_filter import ContextLogFilter

pathfinder = PathFinder()
"""PathFinder for loading userland script modules"""


class SSHContext:
"""Context object for SSH sessions"""
Expand Down Expand Up @@ -289,22 +283,9 @@ async def runscript(self, script: Script) -> Any:
"""

self.log.info(f"Running {script}")
split: list[str] = script.name.split(".")
found: Loader | None = None
mod: ModuleType | None = None

for seg in split:
if mod is not None:
found = pathfinder.find_module(seg, list(mod.__path__))
else:
found = pathfinder.find_module(
seg, get_config("ssh.userland.paths")
)

if found is not None:
mod = found.load_module(seg)

try:
mod = load_userland_module(script.name)
main: Callable[..., Any] = getattr(mod, "main")
return await main(self, *script.args, **script.kwargs)
except (ProcessClosing, Goto):
Expand Down

0 comments on commit 4e8f5db

Please sign in to comment.