Skip to content

Commit

Permalink
fix fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
devkral committed Nov 11, 2024
1 parent 1e429db commit d85fcca
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 42 deletions.
52 changes: 14 additions & 38 deletions monkay/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def load(path: str, allow_splits: str = ":.") -> Any:
return getattr(module, splitted[1])


def load_any(
path: str, attrs: Sequence[str], *, non_first_deprecated: bool = False
) -> Any | None:
def load_any(path: str, attrs: Sequence[str], *, non_first_deprecated: bool = False) -> Any | None:
module = import_module(path)
first_name: None | str = None

Expand Down Expand Up @@ -78,9 +76,7 @@ class Monkay(Generic[L]):
# extensions are pretended to always exist, we check the _extensions_var
_extensions: dict[str, ExtensionProtocol[L]]
_extensions_var: None | ContextVar[None | dict[str, ExtensionProtocol[L]]] = None
_extensions_applied: None | ContextVar[dict[str, ExtensionProtocol[L]] | None] = (
None
)
_extensions_applied: None | ContextVar[dict[str, ExtensionProtocol[L]] | None] = None
_settings_var: ContextVar[BaseSettings | None] | None = None

def __init__(
Expand Down Expand Up @@ -117,9 +113,7 @@ def __init__(
), f"Lazy imports and lazy deprecated imports share: {', '.join(set(self.lazy_imports).intersection(self.deprecated_lazy_imports))}"
self.settings_path = settings_path
if self.settings_path:
self._settings_var = global_dict[settings_ctx_name] = ContextVar(
settings_ctx_name, default=None
)
self._settings_var = global_dict[settings_ctx_name] = ContextVar(settings_ctx_name, default=None)

self.settings_preload_name = settings_preload_name
self.settings_extensions_name = settings_extensions_name
Expand All @@ -134,17 +128,13 @@ def __init__(
all_var = global_dict.setdefault("__all__", [])
global_dict["__all__"] = self.update_all_var(all_var)
if with_instance:
self._instance_var = global_dict[with_instance] = ContextVar(
with_instance, default=None
)
self._instance_var = global_dict[with_instance] = ContextVar(with_instance, default=None)
if with_extensions:
self.extension_order_key_fn = extension_order_key_fn
self._extensions = {}
self._extensions_var = global_dict[with_extensions] = ContextVar(
with_extensions, default=None
)
self._extensions_applied_var = global_dict[extensions_applied_ctx_name] = (
ContextVar(extensions_applied_ctx_name, default=None)
self._extensions_var = global_dict[with_extensions] = ContextVar(with_extensions, default=None)
self._extensions_applied_var = global_dict[extensions_applied_ctx_name] = ContextVar(
extensions_applied_ctx_name, default=None
)
self._handle_extensions()

Expand Down Expand Up @@ -191,9 +181,7 @@ def with_instance(

def apply_extensions(self, use_overwrite: bool = True) -> None:
assert self._extensions_var is not None, "Monkay not enabled for extensions"
extensions: dict[str, ExtensionProtocol[L]] | None = (
self._extensions_var.get() if use_overwrite else None
)
extensions: dict[str, ExtensionProtocol[L]] | None = self._extensions_var.get() if use_overwrite else None
if extensions is None:
extensions = self._extensions
extensions_applied = self._extensions_applied_var.get()
Expand Down Expand Up @@ -227,15 +215,11 @@ def ensure_extension(self, name_or_extension: str | ExtensionProtocol[L]) -> Non
if isinstance(name_or_extension, str):
name = name_or_extension
extension = extensions.get(name)
elif not isclass(name_or_extension) and isinstance(
name_or_extension, ExtensionProtocol
):
elif not isclass(name_or_extension) and isinstance(name_or_extension, ExtensionProtocol):
name = name_or_extension.name
extension = extensions.get(name, name_or_extension)
else:
raise RuntimeError(
'Provided extension "{name_or_extension}" does not implement the ExtensionProtocol'
)
raise RuntimeError('Provided extension "{name_or_extension}" does not implement the ExtensionProtocol')
if name in self._extensions_applied_var.get():
return

Expand All @@ -246,15 +230,11 @@ def ensure_extension(self, name_or_extension: str | ExtensionProtocol[L]) -> Non

def add_extension(
self,
extension: ExtensionProtocol[L]
| type[ExtensionProtocol[L]]
| Callable[[], ExtensionProtocol[L]],
extension: ExtensionProtocol[L] | type[ExtensionProtocol[L]] | Callable[[], ExtensionProtocol[L]],
use_overwrite: bool = True,
) -> None:
assert self._extensions_var is not None, "Monkay not enabled for extensions"
extensions: dict[str, ExtensionProtocol[L]] | None = (
self._extensions_var.get() if use_overwrite else None
)
extensions: dict[str, ExtensionProtocol[L]] | None = self._extensions_var.get() if use_overwrite else None
if extensions is None:
extensions = self._extensions
if callable(extension) or isclass(extension):
Expand Down Expand Up @@ -313,9 +293,7 @@ def with_settings(self, settings: BaseSettings | None) -> Generator:
finally:
self._settings_var.reset(token)

def module_getter(
self, key: str, *, chained_getter: Callable[[str], Any] = _stub_previous_getattr
) -> Any:
def module_getter(self, key: str, *, chained_getter: Callable[[str], Any] = _stub_previous_getattr) -> Any:
lazy_import = self.lazy_imports.get(key)
if lazy_import is None:
deprecated = self.deprecated_lazy_imports.get(key)
Expand All @@ -336,9 +314,7 @@ def module_getter(

def _handle_preloads(self, preloads: Iterable[str]) -> None:
if self.settings_preload_name:
preloads = chain(
preloads, getattr(self.settings, self.settings_preload_name)
)
preloads = chain(preloads, getattr(self.settings, self.settings_preload_name))
for preload in preloads:
splitted = preload.rsplit(":", 1)
try:
Expand Down
5 changes: 1 addition & 4 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def test_attrs():
assert mod.bar() == "bar"
with pytest.warns(DeprecationWarning) as record:
assert mod.deprecated() == "deprecated"
assert (
record[0].message.args[0]
== 'Attribute: "deprecated" is deprecated.\nReason: old.\nUse "super_new" instead.'
)
assert record[0].message.args[0] == 'Attribute: "deprecated" is deprecated.\nReason: old.\nUse "super_new" instead.'


def test_load():
Expand Down

0 comments on commit d85fcca

Please sign in to comment.