From 49f5f8c9f5fc1adb5cebef00be1ef8f30790d847 Mon Sep 17 00:00:00 2001 From: Tomasz Chalupnik Date: Tue, 24 Sep 2024 18:26:31 +0200 Subject: [PATCH] Overload Variables class for better typing experience --- .../contrib/regular_languages/compiler.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/prompt_toolkit/contrib/regular_languages/compiler.py b/src/prompt_toolkit/contrib/regular_languages/compiler.py index dd558a68a..699a600f6 100644 --- a/src/prompt_toolkit/contrib/regular_languages/compiler.py +++ b/src/prompt_toolkit/contrib/regular_languages/compiler.py @@ -42,7 +42,7 @@ from __future__ import annotations import re -from typing import Callable, Dict, Iterable, Iterator, Pattern +from typing import Callable, Dict, Iterable, Iterator, Pattern, TypeVar, overload from typing import Match as RegexMatch from .regex_parser import ( @@ -57,9 +57,7 @@ tokenize_regex, ) -__all__ = [ - "compile", -] +__all__ = ["compile", "Match", "Variables"] # Name of the named group in the regex, matching trailing input. @@ -491,6 +489,9 @@ def end_nodes(self) -> Iterable[MatchVariable]: yield MatchVariable(varname, value, (reg[0], reg[1])) +_T = TypeVar("_T") + + class Variables: def __init__(self, tuples: list[tuple[str, str, tuple[int, int]]]) -> None: #: List of (varname, value, slice) tuples. @@ -502,7 +503,13 @@ def __repr__(self) -> str: ", ".join(f"{k}={v!r}" for k, v, _ in self._tuples), ) - def get(self, key: str, default: str | None = None) -> str | None: + @overload + def get(self, key: str) -> str | None: ... + + @overload + def get(self, key: str, default: _T = None) -> str | _T: ... + + def get(self, key: str, default: _T = None) -> str | _T: items = self.getall(key) return items[0] if items else default