Skip to content

Commit

Permalink
feat: add mypy support and type stubs (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
eonu authored Dec 29, 2024
1 parent dc049d2 commit d894865
Show file tree
Hide file tree
Showing 24 changed files with 286 additions and 150 deletions.
16 changes: 15 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
linting:
executor:
name: python/default
tag: "3.11"
tag: "3.13"
steps:
- checkout
- python/install-packages:
Expand All @@ -18,6 +18,19 @@ jobs:
name: Linting
command: |
poetry run tox -e lint
typechecks:
executor:
name: python/default
tag: "3.13"
steps:
- checkout
- python/install-packages:
pkg-manager: poetry
args: --only base
- run:
name: Typechecking (MyPy)
command: |
poetry run tox -e types
tests:
parameters:
version:
Expand Down Expand Up @@ -53,6 +66,7 @@ workflows:
checks:
jobs:
- linting
- typechecks
- tests:
matrix:
parameters:
Expand Down
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@ base: .check-poetry
dev: .check-poetry
poetry install --sync --only base
poetry run invoke install

# clean temporary repository files
.PHONY: clean
clean: .check-poetry
poetry run invoke clean
36 changes: 20 additions & 16 deletions feud/_internal/_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class ParameterType(enum.Enum):
@dataclasses.dataclass
class ParameterSpec:
type: ParameterType | None = None # noqa: A003
hint: type | None = None
args: t.Iterable[t.Any] = dataclasses.field(default_factory=list)
hint: type | None = None # type: ignore[valid-type]
args: t.Sequence[str] = dataclasses.field(default_factory=list)
kwargs: dict[str, t.Any] = dataclasses.field(default_factory=dict)


Expand All @@ -52,14 +52,14 @@ class CommandState:
def decorate( # noqa: PLR0915
self: t.Self,
func: t.Callable,
) -> click.Command:
) -> click.Command | click.Group:
meta_vars: dict[str, str] = {}
sensitive_vars: dict[str, bool] = {}
positional: list[str] = []
var_positional: str | None = None
params: list[click.Parameter] = []

sig: inspect.signature = inspect.signature(func, eval_str=True)
sig: inspect.Signature = inspect.signature(func, eval_str=True)

for i, (param_name, param_spec) in enumerate(sig.parameters.items()):
# store names of positional arguments
Expand All @@ -81,7 +81,9 @@ def decorate( # noqa: PLR0915
continue
if param_name in self.overrides:
param: click.Parameter = self.overrides[param_name]
sensitive = param.hide_input or param.envvar
sensitive |= bool(param.envvar)
if isinstance(param, click.Option):
sensitive |= param.hide_input
elif param_name in self.arguments:
spec = self.arguments[param_name]
spec.kwargs["type"] = _types.click.get_click_type(
Expand All @@ -96,7 +98,7 @@ def decorate( # noqa: PLR0915
param = click.Option(spec.args, **spec.kwargs)
hide_input = spec.kwargs.get("hide_input")
envvar = spec.kwargs.get("envvar")
sensitive = hide_input or envvar
sensitive |= hide_input or bool(envvar)

# get renamed parameter if @feud.rename used
name: str = self.meta.names["params"].get(param_name, param_name)
Expand All @@ -105,7 +107,7 @@ def decorate( # noqa: PLR0915
param.name = name

# get meta vars and identify sensitive parameters for validate_call
meta_vars[name] = self.get_meta_var(param)
meta_vars[name] = self.get_meta_var(param) or name
sensitive_vars[name] = sensitive

# add the parameter
Expand Down Expand Up @@ -140,7 +142,7 @@ def decorate( # noqa: PLR0915
)

if self.pass_context:
command = click.pass_context(command)
command = click.pass_context(command) # type: ignore[assignment, arg-type]

if click.is_rich:
# apply rich-click styling
Expand All @@ -151,18 +153,20 @@ def decorate( # noqa: PLR0915
)(command)

constructor = click.group if self.is_group else click.command
command = constructor(**self.click_kwargs)(command)
compiled: click.Command | click.Group = constructor(
**self.click_kwargs,
)(command)
compiled.params = params

command.params = params
return compiled

return command

def get_meta_var(self: t.Self, param: click.Parameter) -> str:
def get_meta_var(self: t.Self, param: click.Parameter) -> str | None:
match param:
case click.Argument():
return param.make_metavar()
case click.Option():
return param.opts[0]
return None


def pass_context(sig: inspect.Signature) -> bool:
Expand All @@ -174,7 +178,7 @@ def pass_context(sig: inspect.Signature) -> bool:
return param_name == CONTEXT_PARAM


def get_option(name: str, *, hint: type, negate_flags: bool) -> str:
def get_option(name: str, *, hint: t.Any, negate_flags: bool) -> str:
"""Convert a name into a command-line option.
Additionally negates the option if a boolean flag is provided
Expand All @@ -193,7 +197,7 @@ def get_option(name: str, *, hint: type, negate_flags: bool) -> str:
return option


def get_alias(alias: str, *, hint: type, negate_flags: bool) -> str:
def get_alias(alias: str, *, hint: t.Any, negate_flags: bool) -> str:
"""Negate an alias for a boolean flag and returns a joint declaration
if ``negate_flags`` is ``True``.
Expand Down Expand Up @@ -400,5 +404,5 @@ def get_command(

# generate click.Command and attach original function reference
command = state.decorate(func)
command.__func__ = func
command.__func__ = func # type: ignore[attr-defined]
return command
31 changes: 21 additions & 10 deletions feud/_internal/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from feud import click

AnyCallableT = t.TypeVar("AnyCallableT", bound=t.Callable[..., t.Any])


def validate_call(
func: t.Callable,
Expand All @@ -27,7 +29,7 @@ def validate_call(
positional: list[str],
var_positional: str | None,
pydantic_kwargs: dict[str, t.Any],
) -> t.Callable:
) -> t.Callable[[AnyCallableT], AnyCallableT]:
@ft.wraps(func)
def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Callable:
try:
Expand All @@ -38,22 +40,31 @@ def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Callable:
args += (pos_arg,)

# move *args to positional arguments
var_pos_args = kwargs.pop(
var_positional,
inspect._empty, # noqa: SLF001
)
if var_pos_args is not inspect._empty: # noqa: SLF001
args += var_pos_args
if var_positional is not None:
var_pos_args = kwargs.pop(
var_positional,
inspect._empty, # noqa: SLF001
)
if var_pos_args is not inspect._empty: # noqa: SLF001
args += var_pos_args

# apply renaming for any options
inv_mapping = {v: k for k, v in param_renames.items()}
true_kwargs = {inv_mapping.get(k, k): v for k, v in kwargs.items()}

# create Pydantic configuration
config = pyd.ConfigDict(**pydantic_kwargs)
config = pyd.ConfigDict(
**pydantic_kwargs, # type: ignore[typeddict-item]
)

# validate the function call
return pyd.validate_call(func, config=config)(*args, **true_kwargs)
return pyd.validate_call( # type: ignore[call-overload]
func,
config=config,
)(
*args,
**true_kwargs,
)
except pyd.ValidationError as e:
msg = re.sub(
r"validation error(s?) for (.*)\n",
Expand Down Expand Up @@ -85,4 +96,4 @@ def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Callable:
)
raise click.ClickException(msg) from None

return wrapper
return wrapper # type: ignore[return-value]
2 changes: 1 addition & 1 deletion feud/_internal/_docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_description(
doc = docstring_parser.parse_from_object(func)
elif isinstance(obj, docstring_parser.Docstring):
doc = obj
elif isinstance(obj, t.Callable):
elif callable(obj):
doc = docstring_parser.parse_from_object(obj)

ret = None
Expand Down
17 changes: 12 additions & 5 deletions feud/_internal/_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
# SPDX-License-Identifier: MIT
# This source code is part of the Feud project (https://feud.wiki).

from __future__ import annotations

import typing as t

from feud import click
from feud._internal import _command, _meta

if t.TYPE_CHECKING:
from feud.core.group import Group


def get_group(__cls: type, /) -> click.Group: # type[Group]
func: callable = __cls.__main__
def get_group(__cls: type[Group], /) -> click.Group: # type[Group]
func: t.Callable = __cls.__main__
if isinstance(func, staticmethod):
func = func.__func__

Expand All @@ -29,7 +36,7 @@ def get_group(__cls: type, /) -> click.Group: # type[Group]
)

# generate click.Group and attach original function reference
command = state.decorate(func)
command.__func__ = func
command.__group__ = __cls
command: click.Group = state.decorate(func) # type: ignore[assignment]
command.__func__ = func # type: ignore[attr-defined]
command.__group__ = __cls # type: ignore[attr-defined]
return command
14 changes: 11 additions & 3 deletions feud/_internal/_metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@
from feud.config import Config
from feud.core.command import command

if t.TYPE_CHECKING:
from feud.core.group import Group


class GroupBase(abc.ABCMeta):
def __new__(
__cls: type[GroupBase], # noqa: N804
cls_name: str,
bases: tuple[type, ...],
bases: tuple[type[Group], ...],
namespace: dict[str, t.Any],
**kwargs: t.Any,
) -> type: # type[Group], but circular import
) -> type[Group]:
"""Metaclass for creating groups.
Parameters
Expand Down Expand Up @@ -124,7 +127,12 @@ def __new__(
func, config=namespace["__feud_config__"]
)

group = super().__new__(__cls, cls_name, bases, namespace)
group: type[Group] = super().__new__( # type: ignore[assignment]
__cls,
cls_name,
bases,
namespace,
)

if bases:
# use class-level docstring as help if provided
Expand Down
33 changes: 13 additions & 20 deletions feud/_internal/_sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,18 @@

import dataclasses
from collections import defaultdict
from typing import TypedDict

from feud import click
from feud._internal import _meta


class CommandGroup(TypedDict):
name: str
commands: list[str]
def add_command_sections(group: click.Group, context: list[str]) -> None:
from rich_click.utils import CommandGroupDict


class OptionGroup(TypedDict):
name: str
options: list[str]


def add_command_sections(
group: click.Group, context: list[str]
) -> click.Group:
if feud_group := getattr(group, "__group__", None):
command_groups: dict[str, list[CommandGroup]] = {
command_groups: dict[str, list[CommandGroupDict]] = {
" ".join(context): [
CommandGroup(
CommandGroupDict(
name=section.name,
commands=[
item if isinstance(item, str) else item.name()
Expand All @@ -41,7 +30,7 @@ def add_command_sections(
}

for sub in group.commands.values():
if isinstance(sub, click.Group):
if sub.name and isinstance(sub, click.Group):
add_command_sections(sub, context=[*context, sub.name])

settings = group.context_settings
Expand All @@ -57,10 +46,12 @@ def add_command_sections(

def add_option_sections(
obj: click.Command | click.Group, context: list[str]
) -> click.Command | click.Group:
) -> None:
if isinstance(obj, click.Group):
update_command(obj, context=context)
for sub in obj.commands.values():
if sub.name is None:
continue
if isinstance(sub, click.Group):
add_option_sections(sub, context=[*context, sub.name])
else:
Expand All @@ -70,7 +61,7 @@ def add_option_sections(


def get_opts(option: str, *, command: click.Command) -> list[str]:
func = command.__func__
func = command.__func__ # type: ignore[attr-defined]
name_map = lambda name: name # noqa: E731
meta: _meta.FeudMeta | None = getattr(func, "__feud__", None)
if meta and meta.names:
Expand All @@ -84,6 +75,8 @@ def get_opts(option: str, *, command: click.Command) -> list[str]:


def update_command(command: click.Command, context: list[str]) -> None:
from rich_click.utils import OptionGroupDict

if func := getattr(command, "__func__", None):
meta: _meta.FeudMeta | None = getattr(func, "__feud__", None)
if meta and meta.sections:
Expand All @@ -92,9 +85,9 @@ def update_command(command: click.Command, context: list[str]) -> None:
for option, section_name in options.items():
opts: list[str] = get_opts(option, command=command)
sections[section_name].append(opts[0])
option_groups: dict[str, list[OptionGroup]] = {
option_groups: dict[str, list[OptionGroupDict]] = {
" ".join(context): [
OptionGroup(name=name, options=options)
OptionGroupDict(name=name, options=options)
for name, options in sections.items()
]
}
Expand Down
Loading

0 comments on commit d894865

Please sign in to comment.