Skip to content

Commit

Permalink
Zc/dependency injection (#907)
Browse files Browse the repository at this point in the history
A very simple dependency injection framework for seer.
  • Loading branch information
corps authored Jul 15, 2024
1 parent f576ad2 commit 662071a
Show file tree
Hide file tree
Showing 5 changed files with 400 additions and 3 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
"unidiff",
"tree_sitter_languages",
"tree_sitter",
"johen.*",
"johen",
"google-cloud-storage",
"langfuse.*",
"langfuse",
Expand Down
5 changes: 5 additions & 0 deletions src/seer/bootup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from celery_app.config import CeleryQueues
from seer.automation.utils import AgentError
from seer.db import Session, db, migrate
from seer.dependency_injection import Module

logger = logging.getLogger(__name__)
structlog.configure(
Expand All @@ -24,6 +25,9 @@
]
)

module = Module()
stub_module = Module()


def traces_sampler(sampling_context: dict):
if "wsgi_environ" in sampling_context:
Expand Down Expand Up @@ -98,6 +102,7 @@ def before_send(event, hint):

torch.set_num_threads(int(torch_num_threads))

module.enable()
return app


Expand Down
257 changes: 257 additions & 0 deletions src/seer/dependency_injection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
"""
Provides a basic dependency injection framework that uses callable annotations
to decide how and when to inject.
You can inject classes and values and lists of either with some basic constructs:
@module.provider
@dataclass
class MyService:
other_service: OtherService = injected
MyService() # other_service will be instantiated and cached
Overrides/Stubs and tests can be provided via the `stub_module` or creating a `test_module` fixture
(see conftest.py).
You can also inject normal functions, like so:
@inject
def do_setup(a: int, b: MyService = injected):
...
do_setup(100) # b will be injected automatically.
"""

import dataclasses
import functools
import inspect
import threading
from typing import Annotated, Any, Callable, TypeVar

from johen.generators.annotations import AnnotationProcessingContext

_A = TypeVar("_A")
_C = TypeVar("_C", bound=Callable[[], Any])
_T = TypeVar("_T", bound=type)


@dataclasses.dataclass
class Labeled:
"""
Used to 'label' a type so as to have a unique provider when the type itself is not unique.
eg:
@inject
@dataclass
class Config:
host: Annotated[str, Labeled("host")]
protocol: Annotated[str, Labeled("protocol")]
"""

label: str


@dataclasses.dataclass(frozen=True)
class FactoryAnnotation:
concrete_type: type
is_collection: bool
label: str

@classmethod
def from_annotation(cls, source: Any) -> "FactoryAnnotation":
annotation = AnnotationProcessingContext.from_source(source)
if annotation.origin is Annotated:
label = next((arg.label for arg in annotation.args[1:] if isinstance(arg, Labeled)), "")
inner = FactoryAnnotation.from_annotation(annotation.args[0])
assert not inner.label, f"Cannot get_factory {source}: Annotated has embedded Labeled"
return dataclasses.replace(inner, label=label)
elif annotation.concretely_implements(list):
assert (
len(annotation.args) == 1
), f"Cannot get_factory {source}: list requires at least one argument"
inner = FactoryAnnotation.from_annotation(annotation.args[0])
assert not inner.label, f"Cannot get_factory {source}: list has embedded Labeled"
assert (
not inner.is_collection
), f"Cannot get_factory {source}: collections must be of concrete types, not other lists"
return dataclasses.replace(inner, is_collection=True)

assert (
annotation.origin is None
), f"Cannot get_factory {source}, only concrete types or lists of concrete types are supported"
return FactoryAnnotation(concrete_type=annotation.source, is_collection=False, label="")

@classmethod
def from_factory(cls, c: Callable) -> "FactoryAnnotation":
argspec = inspect.getfullargspec(c)
num_arg_defaults = len(argspec.defaults) if argspec.defaults is not None else 0
num_kwd_defaults = len(argspec.kwonlydefaults) if argspec.kwonlydefaults is not None else 0

# Constructors have implicit self reference and return annotations -- themselves
if inspect.isclass(c):
num_arg_defaults += 1
rv = c
else:
rv = argspec.annotations.get("return", None)
assert rv is not None, "Cannot decorate function without return annotation"

assert num_arg_defaults >= len(
argspec.args
), "Cannot decorate function with required positional args"
assert num_kwd_defaults >= len(
argspec.kwonlyargs
), "Cannot decorate function with required kwd args"
return FactoryAnnotation.from_annotation(rv)


class FactoryNotFound(Exception):
pass


@dataclasses.dataclass
class Module:
registry: dict[FactoryAnnotation, Callable] = dataclasses.field(default_factory=dict)

def provider(self, c: _C) -> _C:
c = inject(c)

key = FactoryAnnotation.from_factory(c)
assert (
key not in self.registry
), f"{key.concrete_type} is already registered for this injector"
self.registry[key] = c
return c

def constant(self, annotation: type[_A], val: _A) -> _A:
key = FactoryAnnotation.from_annotation(annotation)
self.registry[key] = lambda: val
return val

def enable(self):
injector = Injector(self, _cur.injector)
_cur.injector = injector
return injector

def __enter__(self):
return self.enable()

def __exit__(self, exc_type, exc_val, exc_tb):
assert _cur.injector, "Injector state was tampered with, or __exit__ invoked prematurely"
assert (
_cur.injector.module is self
), "Injector state was tampered with, or __exit__ invoked prematurely"
_cur.injector = _cur.injector.parent


class _Injected:
"""
Magical variable indicating that a parameter should be injected when constructed
by an Injector object. Invoking a method that uses an `injected` value directly
will use the currently available injector instance to fill in the default value.
"""

pass


# Marked as Any so it can be a stand in value for any annotation.
injected: Any = _Injected()


def inject(c: _A) -> _A:
original_type = c
if inspect.isclass(c):
c = c.__init__

argspec = inspect.getfullargspec(c)

@functools.wraps(c) # type: ignore
def wrapper(*args: Any, **kwargs: Any) -> Any:
new_kwds = {**kwargs}

if argspec.defaults:
offset = len(argspec.args) - len(argspec.defaults)
for i, d in enumerate(argspec.defaults):
arg_idx = offset + i
arg_name = argspec.args[arg_idx]

if d is injected and len(args) <= arg_idx and arg_name not in new_kwds:
try:
resolved = resolve(argspec.annotations[arg_name])
except KeyError:
raise AssertionError(
f"Cannot inject argument {arg_name} as it lacks annotations"
)

new_kwds[arg_name] = resolved

if argspec.kwonlydefaults:
for k, v in argspec.kwonlydefaults.items():
if v is injected and k not in new_kwds:
try:
new_kwds[k] = resolve(argspec.annotations[k])
except KeyError:
raise AssertionError(f"Cannot inject argument {k} as it lacks annotations")

return c(*args, **new_kwds) # type: ignore

if inspect.isclass(original_type):
return type(original_type.__name__, (original_type,), dict(__init__=wrapper)) # type: ignore

return wrapper # type: ignore


def resolve(source: type[_A]) -> _A:
if _cur.injector is None:
raise FactoryNotFound(f"Cannot resolve '{source}', no module injector is currently active.")

key = FactoryAnnotation.from_annotation(source)

if _cur.seen is None:
_cur.seen = []

try:
if key in _cur.seen:
raise FactoryNotFound(
f"Circular dependency: {' -> '.join(str(k) for k in _cur.seen)} -> {key}"
)
_cur.seen.append(key)
return _cur.injector.get(source)
finally:
_cur.seen.clear()


@dataclasses.dataclass
class Injector:
module: Module
parent: "Injector | None"
_cache: dict[FactoryAnnotation, Any] = dataclasses.field(default_factory=dict)

@property
def cache(self) -> dict[FactoryAnnotation, Any]:
if _cur.injector is not None:
return _cur.injector._cache
return self._cache

def get(self, source: type[_A]) -> _A:
key = FactoryAnnotation.from_annotation(source)
if key in self.cache:
return self.cache[key]

try:
f = self.module.registry[key]
except KeyError:
if self.parent is not None:
return self.parent.get(source)
raise FactoryNotFound(f"No registered factory for {source}")

rv = self.cache[key] = f()
return rv


class _Cur(threading.local):
injector: Injector | None = None
seen: list[FactoryAnnotation] | None = None


_cur = _Cur()
14 changes: 13 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,23 @@
from johen.generators import pydantic, sqlalchemy
from sqlalchemy import text

from seer.bootup import CELERY_CONFIG
from seer.bootup import CELERY_CONFIG, stub_module
from seer.db import Session, db
from seer.dependency_injection import Module
from seer.inference_models import reset_loading_state


@pytest.fixture
def test_module() -> Module:
return stub_module


@pytest.fixture(autouse=True)
def enable_test_injector(test_module: Module) -> None:
with test_module:
yield


@pytest.fixture(autouse=True, scope="session")
def configure_environment():
os.environ["DATABASE_URL"] = os.environ["DATABASE_URL"].replace("db", "test-db")
Expand Down
Loading

0 comments on commit 662071a

Please sign in to comment.