Improve typed interface of calculate_implied_probabilities() #4

Open
peterschutt opened this issue Jan 30, 2024 · 2 comments

@peterschutt
Contributor

peterschutt commented Jan 30, 2024

The function has a boolean flag argument (full_output) that changes its return type, so we can use @overload to specify the relationship between the arguments and the return type.

For example,

reveal_type(shin.calculate_implied_probabilities((1.5, 2.5, 3.5)))
reveal_type(shin.calculate_implied_probabilities({"a": 1.5, "b": 2.5, "c": 3.5}))
reveal_type(shin.calculate_implied_probabilities((1.5, 2.5, 3.5), full_output=True))
reveal_type(shin.calculate_implied_probabilities({"a": 1.5, "b": 2.5, "c": 3.5}, full_output=True))

With the 0.1.1 branch, mypy reports:

src/domain/utils.py:22: note: Revealed type is "Union[builtins.dict[builtins.str, Any], builtins.list[builtins.float], builtins.dict[Any, builtins.float]]"
src/domain/utils.py:23: note: Revealed type is "Union[builtins.dict[builtins.str, Any], builtins.list[builtins.float], builtins.dict[Any, builtins.float]]"
src/domain/utils.py:24: note: Revealed type is "Union[builtins.dict[builtins.str, Any], builtins.list[builtins.float], builtins.dict[Any, builtins.float]]"
src/domain/utils.py:25: note: Revealed type is "Union[builtins.dict[builtins.str, Any], builtins.list[builtins.float], builtins.dict[Any, builtins.float]]"

Using overloads, that becomes:

src/domain/utils.py:22: note: Revealed type is "builtins.list[builtins.float]"
src/domain/utils.py:23: note: Revealed type is "builtins.dict[Any, builtins.float]"
src/domain/utils.py:24: note: Revealed type is "builtins.dict[builtins.str, Any]"
src/domain/utils.py:25: note: Revealed type is "builtins.dict[builtins.str, Any]"

... which is a much nicer experience downstream as we don't need to narrow the return type somehow before we go on to use it.
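To make the narrowing burden concrete, here is a minimal sketch. `implied_probs` is a hypothetical stand-in that only mirrors the union return type of the 0.1.1 signature (it normalises inverse odds and is not Shin's method), and `largest_probability` is an invented downstream caller.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Union


def implied_probs(
    odds: Union[Sequence[float], Mapping[Any, float]],
    full_output: bool = False,
) -> Union[dict[str, Any], list[float], dict[Any, float]]:
    """Hypothetical stand-in: plain normalisation, NOT Shin's method."""
    result: Union[list[float], dict[Any, float]]
    if isinstance(odds, Mapping):
        total = sum(1 / o for o in odds.values())
        result = {k: (1 / o) / total for k, o in odds.items()}
    else:
        total = sum(1 / o for o in odds)
        result = [(1 / o) / total for o in odds]
    if full_output:
        # The key name here is an assumption made for this sketch.
        return {"implied_probabilities": result}
    return result


def largest_probability(odds: Sequence[float]) -> float:
    probs = implied_probs(odds)
    # The checker only sees the three-way union here, so the caller has
    # to narrow it before treating the result as a list of floats.
    if not isinstance(probs, list):
        raise TypeError("expected a list for sequence input")
    return max(probs)
```

With overloads in place, the `isinstance` check in `largest_probability` would no longer be needed, because the checker could infer `list[float]` directly from the argument types.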

LMK if this is something you'd be interested in a PR for. There are potentially a lot of overloads required, as all of the arguments can be specified either positionally or by keyword, and mypy requires overloads that cover every scenario. E.g., here is the diff that I've got to produce the above:

(.venv) peter@pop-os:~/PycharmProjects/shin$ git diff
diff --git a/python/shin/__init__.py b/python/shin/__init__.py
index 2ba4c93..0d58daa 100644
--- a/python/shin/__init__.py
+++ b/python/shin/__init__.py
@@ -1,7 +1,7 @@
-from collections.abc import Collection
+from collections.abc import Sequence
 from collections.abc import Mapping
 from math import sqrt
-from typing import Any, Union
+from typing import Any, Literal, TypeVar, Union, overload
 
 
 from .shin import optimise as _optimise_rust
@@ -15,7 +15,7 @@ def _optimise(
     convergence_threshold: float = 1e-12,
 ) -> tuple[float, float, float]:
     delta = float("Inf")
-    z = 0
+    z = 0.0
     iterations = 0
     while delta > convergence_threshold and iterations < max_iterations:
         z0 = z
@@ -31,8 +31,85 @@ def _optimise(
     return z, delta, iterations
 
 
+# full output False as positional argument
+# sequence input
+@overload
 def calculate_implied_probabilities(
-    odds: Union[Collection[float], Mapping[Any, float]],
+    odds: Sequence[float],
+    max_iterations: int,
+    convergence_threshold: float,
+    full_output: Literal[False],
+    force_python_optimiser: bool = ...,
+) -> list[float]:
+    ...
+
+
+# mapping input
+@overload
+def calculate_implied_probabilities(
+    odds: Mapping[Any, float],
+    max_iterations: int,
+    convergence_threshold: float,
+    full_output: Literal[False],
+    force_python_optimiser: bool = ...,
+) -> dict[Any, float]:
+    ...
+
+
+# full output False as keyword argument, or default False
+# sequence input
+@overload
+def calculate_implied_probabilities(
+    odds: Sequence[float],
+    *,
+    max_iterations: int = 1000,
+    convergence_threshold: float = 1e-12,
+    full_output: Literal[False] = False,
+    force_python_optimiser: bool = False,
+) -> list[float]:
+    ...
+
+
+# mapping input
+@overload
+def calculate_implied_probabilities(
+    odds: Mapping[Any, float],
+    *,
+    max_iterations: int = 1000,
+    convergence_threshold: float = 1e-12,
+    full_output: Literal[False] = False,
+    force_python_optimiser: bool = False,
+) -> dict[Any, float]:
+    ...
+
+
+# full output True as positional argument
+@overload
+def calculate_implied_probabilities(
+    odds: Union[Sequence[float], Mapping[Any, float]],
+    max_iterations: int,
+    convergence_threshold: float,
+    full_output: Literal[True],
+    force_python_optimiser: bool = ...,
+) -> dict[str, Any]:
+    ...
+
+
+# full output True as keyword argument
+@overload
+def calculate_implied_probabilities(
+    odds: Union[Sequence[float], Mapping[Any, float]],
+    *,
+    max_iterations: int = 1000,
+    convergence_threshold: float = 1e-12,
+    full_output: Literal[True],
+    force_python_optimiser: bool = False,
+) -> dict[str, Any]:
+    ...
+
+
+def calculate_implied_probabilities(
+    odds: Union[Sequence[float], Mapping[Any, float]],
     max_iterations: int = 1000,
     convergence_threshold: float = 1e-12,
     full_output: bool = False,
diff --git a/python/shin/shin.pyi b/python/shin/shin.pyi
index e69de29..1ee3b6f 100644
--- a/python/shin/shin.pyi
+++ b/python/shin/shin.pyi
@@ -0,0 +1,8 @@
+def optimise(
+    inverse_odds: list[float],
+    sum_inverse_odds: float,
+    n: int,
+    max_iterations: int = 1000,
+    convergence_threshold: float = 1e-12,
+) -> tuple[float, float, float]:
+    ...

The number of overloads could be reduced by making the args other than odds keyword-only, which would be an OK fit IMO:

def calculate_implied_probabilities(
    odds: Union[Sequence[float], Mapping[Any, float]],
    *,
    max_iterations: int = 1000,
    convergence_threshold: float = 1e-12,
    full_output: bool = False,
    force_python_optimiser: bool = False,
) -> Union[dict[str, Any], list[float], dict[Any, float]]:
    ...
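As a rough sketch of how the overload set might shrink under that keyword-only signature, three overloads would seem to be enough, since there are no longer separate positional and keyword spellings of full_output to cover. The overload signatures follow the diff above; the body of the implementation is a placeholder assumption (simple normalisation, not the package's real algorithm), and the "implied_probabilities" key is invented for illustration.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Union, overload


# Sequence input, full_output omitted or False
@overload
def calculate_implied_probabilities(
    odds: Sequence[float],
    *,
    max_iterations: int = ...,
    convergence_threshold: float = ...,
    full_output: Literal[False] = ...,
    force_python_optimiser: bool = ...,
) -> list[float]: ...


# Mapping input, full_output omitted or False
@overload
def calculate_implied_probabilities(
    odds: Mapping[Any, float],
    *,
    max_iterations: int = ...,
    convergence_threshold: float = ...,
    full_output: Literal[False] = ...,
    force_python_optimiser: bool = ...,
) -> dict[Any, float]: ...


# Either input, full_output=True (required keyword, no default)
@overload
def calculate_implied_probabilities(
    odds: Union[Sequence[float], Mapping[Any, float]],
    *,
    max_iterations: int = ...,
    convergence_threshold: float = ...,
    full_output: Literal[True],
    force_python_optimiser: bool = ...,
) -> dict[str, Any]: ...


def calculate_implied_probabilities(
    odds: Union[Sequence[float], Mapping[Any, float]],
    *,
    max_iterations: int = 1000,
    convergence_threshold: float = 1e-12,
    full_output: bool = False,
    force_python_optimiser: bool = False,
) -> Union[dict[str, Any], list[float], dict[Any, float]]:
    # Placeholder body for this sketch only: normalise inverse odds and
    # ignore the optimiser settings. This is NOT the real implementation.
    result: Union[list[float], dict[Any, float]]
    if isinstance(odds, Mapping):
        total = sum(1 / o for o in odds.values())
        result = {k: (1 / o) / total for k, o in odds.items()}
    else:
        total = sum(1 / o for o in odds)
        result = [(1 / o) / total for o in odds]
    if full_output:
        return {"implied_probabilities": result}
    return result
```

One caveat with this sketch (and with the Literal-based overloads in general): a call that passes a plain bool variable for full_output would not match any of the three overloads, so a fallback overload returning the full union may be needed if callers do that.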
@mberk
Owner

mberk commented Jan 30, 2024

Sounds good. Might I suggest, as per the discussion on #2, that this issue widen in scope to include all of the changes you feel are appropriate to get the package in good shape for mypy, including incorporating it into CI?

@peterschutt
Contributor Author

Great, thanks. I will open a PR for this stuff over the next day or so.
