Skip to content

Commit

Permalink
Strict typing for all files
Browse files Browse the repository at this point in the history
Summary: Change files (except `models.py`) with `# pyre-unsafe` to `# pyre-strict` and fix the resulting typing issues.  The main issue was that `decorators.py` was missing annotations which resulted in loss of type information on decorated functions. Fixing this will allow safer refactoring.

Reviewed By: alexblanck

Differential Revision: D66972001

fbshipit-source-id: daceec269f991cdd7d73ca35804d7322d4ddd152
  • Loading branch information
scottblaha authored and facebook-github-bot committed Dec 13, 2024
1 parent 0106793 commit b1bb482
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 132 deletions.
3 changes: 2 additions & 1 deletion sapp/db_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sqlalchemy.dialects.mysql import BIGINT
from sqlalchemy.engine import Dialect
from sqlalchemy.orm import Session
from typing_extensions import Self

from .db import DB
from .iterutil import inclusive_range, split_every
Expand Down Expand Up @@ -346,7 +347,7 @@ def reserve(
# `typing.Type` to avoid runtime subscripting errors.
saving_classes: List[Type],
item_counts: Optional[Dict[str, int]] = None,
) -> "PrimaryKeyGeneratorBase":
) -> Self:
"""
session - Session for DB operations.
saving_classes - class objects that need to be saved e.g. Issue, Run
Expand Down
30 changes: 15 additions & 15 deletions sapp/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,32 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict

import datetime
import logging
import sys
import time
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Optional
from typing import Callable, Generator, List, Optional, ParamSpec, Type, TypeVar

log = logging.getLogger("sapp")
log: logging.Logger = logging.getLogger("sapp")

P = ParamSpec("P")
R = TypeVar("R")


class retryable:
def __init__(
self, num_tries: int = 1, retryable_exs: Optional[List[Any]] = None
self, num_tries: int = 1, retryable_exs: Optional[List[Type[Exception]]] = None
) -> None:
self.num_tries = num_tries
self.retryable_exs = retryable_exs

def __call__(self, func):
def __call__(self, func: Callable[P, R]) -> Callable[P, R]:
@wraps(func)
def new_func(*args, **kwargs):
def new_func(*args: P.args, **kwargs: P.kwargs) -> R:
try_num = 1
while True:
try:
Expand All @@ -41,12 +44,12 @@ def new_func(*args, **kwargs):
return new_func


def log_time(func: Callable[..., Any]) -> Callable[..., Any]:
def log_time(func: Callable[P, R]) -> Callable[P, R]:
"""Log the time it takes to run a function. It's sort of like timeit, but
prettier.
"""

def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.time()
log.info("%s starting...", func.__name__.title())
ret = func(*args, **kwargs)
Expand All @@ -65,25 +68,22 @@ class UserError(Exception):


@contextmanager
def catch_user_error():
def catch_user_error() -> Generator[None, None, None]:
try:
yield
except UserError as error:
print(str(error), file=sys.stderr)


@contextmanager
def catch_keyboard_interrupt():
def catch_keyboard_interrupt() -> Generator[None, None, None]:
try:
yield
except KeyboardInterrupt:
print("\nOperation aborted.", file=sys.stderr)


# For use on enums to alias upper case value.
#
# FLAKE8 does not understand that this is a static property
# flake8: noqa B902
class classproperty(property):
def __get__(self, cls, owner):
return classmethod(self.fget).__get__(None, owner)()
def __get__(self, cls: object, owner: Optional[Type[R]]) -> R:
return classmethod(self.fget or (lambda _: None)).__get__(None, owner)()
15 changes: 10 additions & 5 deletions sapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ class SharedText(Base, PrepareMixin, RecordMixin):
nullable=False,
)

kind: Column[str] = Column(
# pyre-fixme[8]: Attribute has type `Column[str]`; used as `Column[SharedTextKind]`.
kind: Column[SharedTextKind] = Column(
Enum(SharedTextKind), server_default="feature", nullable=False
)

Expand Down Expand Up @@ -930,7 +931,8 @@ class Run(Base):
viewonly=True,
)

status: Column[str] = Column(
# pyre-fixme[8]: Attribute has type `Column[str]`; used as `Column[RunStatus]`.
status: Column[RunStatus] = Column(
Enum(RunStatus), server_default="finished", nullable=False, index=True
)

Expand Down Expand Up @@ -962,7 +964,8 @@ class Run(Base):
server_default="0",
)

purge_status: Column[str] = Column(
# pyre-fixme[8]: Attribute has type `Column[str]`; used as `Column[PurgeStatus]`.
purge_status: Column[PurgeStatus] = Column(
Enum(PurgeStatus),
server_default="unpurged",
nullable=False,
Expand Down Expand Up @@ -1058,7 +1061,8 @@ class MetaRun(Base):
default=CURRENT_DB_VERSION,
)

status: Column[str] = Column(
# pyre-fixme[8]: Attribute has type `Column[str]`; used as `Column[RunStatus]`.
status: Column[RunStatus] = Column(
Enum(RunStatus), server_default="finished", nullable=False, index=True
)

Expand Down Expand Up @@ -1262,7 +1266,8 @@ class TraceFrame(Base, PrepareMixin, RecordMixin):
server_default="",
)

reachability: Column[str] = Column(
# pyre-fixme[8]: Attribute has type `Column[str]`; used as `Column[FrameReachability]`.
reachability: Column[FrameReachability] = Column(
Enum(FrameReachability),
server_default="unreachable",
nullable=False,
Expand Down
2 changes: 1 addition & 1 deletion sapp/pipeline/database_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
# pyre-fixme[13]: Attribute `summary` is never initialized.
self.summary: Summary

@log_time
@log_time # pyre-ignore[56]: Pyre can't support this yet.
def run(
self, input: List[TraceGraph], summary: Summary
) -> Tuple[RunSummary, Summary]:
Expand Down
2 changes: 1 addition & 1 deletion sapp/pipeline/model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _create_empty_run(
job_id=self.summary["job_id"],
issue_instances=[],
date=datetime.datetime.now(),
status=status.name,
status=status,
status_description=status_description,
repository=self.summary["repository"],
branch=self.summary["branch"],
Expand Down
Loading

0 comments on commit b1bb482

Please sign in to comment.