Skip to content

Commit

Permalink
support annotated FieldInfo range types
Browse files Browse the repository at this point in the history
  • Loading branch information
eonu committed Dec 30, 2024
1 parent 1711fb4 commit 3f8f350
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 32 deletions.
66 changes: 34 additions & 32 deletions feud/_internal/_types/click.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,6 @@ def resolve_annotated(

arg_list = list(parent_args.values())
two_field_subtype = t.Annotated[*arg_list[:2]] # type: ignore[valid-type]

# integer types
if two_field_subtype == pyd.PositiveInt:
return click.IntRange(min=0, min_open=True)
Expand All @@ -476,13 +475,12 @@ def resolve_annotated(
if two_field_subtype == pyd.NonPositiveFloat:
return click.FloatRange(max=0, max_open=False)

# int / float range types
if is_pyd_conint(base_type, parent_args):
return get_click_range_type(parent_args, range_type=click.IntRange)
if is_pyd_confloat(base_type, parent_args):
return get_click_range_type(parent_args, range_type=click.FloatRange)
if is_pyd_condecimal(base_type, parent_args):
return get_click_range_type(parent_args, range_type=click.FloatRange)
# int / float / decimal range types
if interval := get_interval(parent_args):
if base_type is int:
return get_click_range_type(interval, click.IntRange)
if base_type in (float, decimal.Decimal):
return get_click_range_type(interval, click.FloatRange)

# file / directory types
if two_field_subtype == pyd.FilePath:
Expand All @@ -495,39 +493,35 @@ def resolve_annotated(
return None


def is_pyd_conint(base_type: t.Any, parent_args: AnnotatedArgDict) -> bool:
return base_type is int and isinstance(parent_args.get(2), ta.Interval)


def is_pyd_confloat(base_type: t.Any, parent_args: AnnotatedArgDict) -> bool:
return base_type is float and isinstance(parent_args.get(2), ta.Interval)


def is_pyd_condecimal(base_type: t.Any, parent_args: AnnotatedArgDict) -> bool:
return base_type is decimal.Decimal and isinstance(
parent_args.get(2), ta.Interval
)


def is_namedtuple(hint: t.Any) -> bool:
if hint is None:
return False
if not inspect.isclass(hint):
return False
return issubclass(hint, tuple) and hasattr(hint, "_fields")
def get_interval(parent_args: AnnotatedArgDict) -> ta.Interval | None:
for v in parent_args.values():
if isinstance(v, ta.Interval):
return v
if isinstance(v, pyd.fields.FieldInfo):
interval = {bound: None for bound in ("ge", "gt", "le", "lt")}
for meta in v.metadata:
match meta:
case ta.Ge():
interval["ge"] = meta.ge # type: ignore[assignment]
case ta.Gt():
interval["gt"] = meta.gt # type: ignore[assignment]
case ta.Le():
interval["le"] = meta.le # type: ignore[assignment]
case ta.Lt():
interval["lt"] = meta.lt # type: ignore[assignment]
if any(interval.values()):
return ta.Interval(**interval)
return None


def get_click_range_type(
args: AnnotatedArgDict,
*,
interval: ta.Interval,
range_type: type[click.IntRange] | type[click.FloatRange],
) -> click.IntRange | click.FloatRange | None:
min_: ta.SupportsGe | ta.SupportsGt | None = None
max_: ta.SupportsLe | ta.SupportsLt | None = None

min_open, max_open = False, False

interval: ta.Interval | None = args.get(2)
if interval is None:
return None
if interval.gt is not None:
Expand All @@ -549,3 +543,11 @@ def get_click_range_type(
min_open=min_open,
max_open=max_open,
)


def is_namedtuple(hint: t.Any) -> bool:
if hint is None:
return False
if not inspect.isclass(hint):
return False
return issubclass(hint, tuple) and hasattr(hint, "_fields")
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,17 @@
and x.max == t.Decimal("3.14")
and x.max_open is True,
),
(
t.Annotated[
t.Decimal,
pyd.Field(lt=t.Decimal("3.14"), ge=t.Decimal("0.01")),
],
lambda x: isinstance(x, click.FloatRange)
and x.min == t.Decimal("0.01")
and x.min_open is False
and x.max == t.Decimal("3.14")
and x.max_open is True,
),
(
t.confloat(lt=3.14, ge=0.01),
lambda x: isinstance(x, click.FloatRange)
Expand All @@ -183,6 +194,14 @@
and x.max == 3.14
and x.max_open is True,
),
(
t.Annotated[float, pyd.Field(lt=3.14, ge=0.01)],
lambda x: isinstance(x, click.FloatRange)
and x.min == 0.01
and x.min_open is False
and x.max == 3.14
and x.max_open is True,
),
(t.confrozenset(int, max_length=1), click.INT),
(
t.conint(lt=3, ge=0),
Expand All @@ -192,6 +211,14 @@
and x.max == 3
and x.max_open is True,
),
(
t.Annotated[int, pyd.Field(lt=3, ge=0)],
lambda x: isinstance(x, click.IntRange)
and x.min == 0
and x.min_open is False
and x.max == 3
and x.max_open is True,
),
(t.conlist(int, max_length=1), click.INT),
(t.conset(int, max_length=1), click.INT),
(t.constr(max_length=1), click.STRING),
Expand Down

0 comments on commit 3f8f350

Please sign in to comment.