diff --git a/sapp/db_support.py b/sapp/db_support.py index d11625c..b2a60fa 100644 --- a/sapp/db_support.py +++ b/sapp/db_support.py @@ -17,6 +17,7 @@ from sqlalchemy.dialects.mysql import BIGINT from sqlalchemy.engine import Dialect from sqlalchemy.orm import Session +from typing_extensions import Self from .db import DB from .iterutil import inclusive_range, split_every @@ -346,7 +347,7 @@ def reserve( # `typing.Type` to avoid runtime subscripting errors. saving_classes: List[Type], item_counts: Optional[Dict[str, int]] = None, - ) -> "PrimaryKeyGeneratorBase": + ) -> Self: """ session - Session for DB operations. saving_classes - class objects that need to be saved e.g. Issue, Run diff --git a/sapp/decorators.py b/sapp/decorators.py index 83be91c..ffe0fad 100644 --- a/sapp/decorators.py +++ b/sapp/decorators.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import datetime import logging @@ -11,21 +11,24 @@ import time from contextlib import contextmanager from functools import wraps -from typing import Any, Callable, List, Optional +from typing import Callable, Generator, List, Optional, ParamSpec, Type, TypeVar -log = logging.getLogger("sapp") +log: logging.Logger = logging.getLogger("sapp") + +P = ParamSpec("P") +R = TypeVar("R") class retryable: def __init__( - self, num_tries: int = 1, retryable_exs: Optional[List[Any]] = None + self, num_tries: int = 1, retryable_exs: Optional[List[Type[Exception]]] = None ) -> None: self.num_tries = num_tries self.retryable_exs = retryable_exs - def __call__(self, func): + def __call__(self, func: Callable[P, R]) -> Callable[P, R]: @wraps(func) - def new_func(*args, **kwargs): + def new_func(*args: P.args, **kwargs: P.kwargs) -> R: try_num = 1 while True: try: @@ -41,12 +44,12 @@ def new_func(*args, **kwargs): return new_func -def log_time(func: Callable[..., Any]) -> Callable[..., Any]: +def log_time(func: Callable[P, R]) -> Callable[P, R]: """Log the time it takes to run a function. It's sort of like timeit, but prettier. """ - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: start_time = time.time() log.info("%s starting...", func.__name__.title()) ret = func(*args, **kwargs) @@ -65,7 +68,7 @@ class UserError(Exception): @contextmanager -def catch_user_error(): +def catch_user_error() -> Generator[None, None, None]: try: yield except UserError as error: @@ -73,7 +76,7 @@ def catch_user_error(): @contextmanager -def catch_keyboard_interrupt(): +def catch_keyboard_interrupt() -> Generator[None, None, None]: try: yield except KeyboardInterrupt: @@ -81,9 +84,6 @@ def catch_keyboard_interrupt(): # For use on enums to alias upper case value. -# -# FLAKE8 does not understand that this is a static property -# flake8: noqa B902 class classproperty(property): - def __get__(self, cls, owner): - return classmethod(self.fget).__get__(None, owner)() + def __get__(self, cls: object, owner: Optional[Type[R]]) -> R: + return classmethod(self.fget or (lambda _: None)).__get__(None, owner)() diff --git a/sapp/models.py b/sapp/models.py index 07b4cbd..a3b2b09 100644 --- a/sapp/models.py +++ b/sapp/models.py @@ -279,7 +279,8 @@ class SharedText(Base, PrepareMixin, RecordMixin): nullable=False, ) - kind: Column[str] = Column( + # pyre-fixme[8]: Attribute has type `Column[str]`; used as `Column[SharedTextKind]`. + kind: Column[SharedTextKind] = Column( Enum(SharedTextKind), server_default="feature", nullable=False ) @@ -930,7 +931,8 @@ class Run(Base): viewonly=True, ) - status: Column[str] = Column( + # pyre-fixme[8]: Attribute has type `Column[str]`; used as `Column[RunStatus]`. + status: Column[RunStatus] = Column( Enum(RunStatus), server_default="finished", nullable=False, index=True ) @@ -962,7 +964,8 @@ class Run(Base): server_default="0", ) - purge_status: Column[str] = Column( + # pyre-fixme[8]: Attribute has type `Column[str]`; used as `Column[PurgeStatus]`. + purge_status: Column[PurgeStatus] = Column( Enum(PurgeStatus), server_default="unpurged", nullable=False, @@ -1058,7 +1061,8 @@ class MetaRun(Base): default=CURRENT_DB_VERSION, ) - status: Column[str] = Column( + # pyre-fixme[8]: Attribute has type `Column[str]`; used as `Column[RunStatus]`. + status: Column[RunStatus] = Column( Enum(RunStatus), server_default="finished", nullable=False, index=True ) @@ -1262,7 +1266,8 @@ class TraceFrame(Base, PrepareMixin, RecordMixin): server_default="", ) - reachability: Column[str] = Column( + # pyre-fixme[8]: Attribute has type `Column[str]`; used as `Column[FrameReachability]`. + reachability: Column[FrameReachability] = Column( Enum(FrameReachability), server_default="unreachable", nullable=False, diff --git a/sapp/pipeline/database_saver.py b/sapp/pipeline/database_saver.py index f983b18..960cc32 100644 --- a/sapp/pipeline/database_saver.py +++ b/sapp/pipeline/database_saver.py @@ -61,7 +61,7 @@ def __init__( # pyre-fixme[13]: Attribute `summary` is never initialized. self.summary: Summary - @log_time + @log_time # pyre-ignore[56]: Pyre can't support this yet. def run( self, input: List[TraceGraph], summary: Summary ) -> Tuple[RunSummary, Summary]: diff --git a/sapp/pipeline/model_generator.py b/sapp/pipeline/model_generator.py index b5994f9..d1192ab 100644 --- a/sapp/pipeline/model_generator.py +++ b/sapp/pipeline/model_generator.py @@ -159,7 +159,7 @@ def _create_empty_run( job_id=self.summary["job_id"], issue_instances=[], date=datetime.datetime.now(), - status=status.name, + status=status, status_description=status_description, repository=self.summary["repository"], branch=self.summary["branch"], diff --git a/sapp/tests/fake_object_generator.py b/sapp/tests/fake_object_generator.py index 96b8f60..c6330a3 100644 --- a/sapp/tests/fake_object_generator.py +++ b/sapp/tests/fake_object_generator.py @@ -4,13 +4,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import datetime -from typing import Callable, Optional +from typing import Callable, cast, List, Optional, Tuple from ..bulk_saver import BulkSaver - +from ..db import DB from ..models import ( ClassTypeInterval, DBID, @@ -37,15 +37,18 @@ def __init__(self, graph: Optional[TraceGraph] = None, run_id: int = 0) -> None: self.reinit(run_id) self.graph = graph - def reinit(self, run_id) -> None: + def reinit(self, run_id: int) -> None: self.saver = BulkSaver() self.handle = 0 self.source_name_id = 0 self.sink_name_id = 0 self.shared_text_name_id = 0 self.run_id = run_id + self.metarun_id = 0 - def save_all(self, db, before_save: Optional[Callable[[], None]] = None) -> None: + def save_all( + self, db: DB, before_save: Optional[Callable[[], None]] = None + ) -> None: if self.graph: self.graph.update_bulk_saver(self.saver) self.saver.prepare_all(db) @@ -57,26 +60,28 @@ def save_all(self, db, before_save: Optional[Callable[[], None]] = None) -> None def issue( self, callable: str = "Foo.barMethod", - handle=None, - code=None, + handle: Optional[str] = None, + code: Optional[int] = None, status: str = "uncategorized", - ): + ) -> Issue: self.handle += 1 now = datetime.datetime.now() - callable = self.callable(callable) - result = Issue.Record( - id=IssueDBID(), - handle=str(self.handle) if not handle else handle, - code=code or (6015 + self.handle), - callable_id=callable.id, - status=status, - detected_time=int(now.timestamp()), - first_instance_id=DBID(10072), - update_time=0, - triage_duration=0, + callable_shared_text = self.callable(callable) + result = cast( + Issue, + Issue.Record( + id=IssueDBID(), + handle=str(self.handle) if not handle else handle, + code=code or (6015 + self.handle), + callable_id=callable_shared_text.id, + status=status, + detected_time=int(now.timestamp()), + first_instance_id=DBID(10072), + update_time=0, + triage_duration=0, + ), ) if self.graph: - # pyre-fixme[6]: For 1st param expected `Issue` but got `Munch`. self.graph.add_issue(result) else: self.saver.add(result) @@ -89,14 +94,14 @@ def precondition( callee: str = "triple_meh", callee_port: str = "at the beginning of time", filename: str = "lib/server/posts/request.py", - location=(4, 5, 6), - leaves=None, - reachability=FrameReachability.UNREACHABLE, + location: Tuple[int, ...] = (4, 5, 6), + leaves: Optional[List[Tuple[SharedText, object]]] = None, + reachability: FrameReachability = FrameReachability.UNREACHABLE, preserves_type_context: bool = False, type_interval_lower: int = 5, type_interval_upper: int = 7, - run_id=None, - ): + run_id: Optional[int] = None, + ) -> TraceFrame: leaves = leaves or [] filename_record = self.filename(filename) caller_record = self.callable(caller) @@ -141,13 +146,13 @@ def postcondition( callee: str = "quintuple_meh", callee_port: str = "callee_meh", filename: str = "lib/server/posts/response.py", - location=(4, 5, 6), - leaves=None, + location: Tuple[int, ...] = (4, 5, 6), + leaves: Optional[List[Tuple[SharedText, object]]] = None, preserves_type_context: bool = False, type_interval_lower: int = 5, type_interval_upper: int = 7, - run_id=None, - ): + run_id: Optional[int] = None, + ) -> TraceFrame: leaves = leaves or [] filename_record = self.filename(filename) caller_record = self.callable(caller) @@ -185,7 +190,7 @@ def postcondition( self.saver.add(trace_frame) return trace_frame - def shared_text(self, contents, kind): + def shared_text(self, contents: str, kind: SharedTextKind) -> SharedText: if self.graph: shared_text = self.graph.get_shared_text(kind, contents) if shared_text is not None: @@ -198,7 +203,12 @@ def shared_text(self, contents, kind): self.saver.add(result) return result - def run(self, differential_id=None, job_id=None, kind=None): + def run( + self, + differential_id: Optional[int] = None, + job_id: Optional[str] = None, + kind: Optional[str] = None, + ) -> Run: self.run_id += 1 # Not added to bulksaver or graph return Run( @@ -211,7 +221,9 @@ def run(self, differential_id=None, job_id=None, kind=None): kind=kind, ) - def metarun(self, status=RunStatus.FINISHED, kind="test_metarun"): + def metarun( + self, status: RunStatus = RunStatus.FINISHED, kind: str = "test_metarun" + ) -> MetaRun: self.metarun_id += 1 # Not added to bulksaver or graph return MetaRun( @@ -221,28 +233,28 @@ def metarun(self, status=RunStatus.FINISHED, kind="test_metarun"): status=status, ) - def feature(self, name: str = "via:feature"): + def feature(self, name: str = "via:feature") -> SharedText: return self.shared_text(contents=name, kind=SharedTextKind.FEATURE) - def source(self, name: str = "source"): + def source(self, name: str = "source") -> SharedText: return self.shared_text(contents=name, kind=SharedTextKind.SOURCE) - def source_detail(self, name: str = "source_detail"): + def source_detail(self, name: str = "source_detail") -> SharedText: return self.shared_text(contents=name, kind=SharedTextKind.SOURCE_DETAIL) - def sink(self, name: str = "sink"): + def sink(self, name: str = "sink") -> SharedText: return self.shared_text(contents=name, kind=SharedTextKind.SINK) - def sink_detail(self, name: str = "sink_detail"): + def sink_detail(self, name: str = "sink_detail") -> SharedText: return self.shared_text(contents=name, kind=SharedTextKind.SINK_DETAIL) - def filename(self, name: str = "/r/some/filename.py"): + def filename(self, name: str = "/r/some/filename.py") -> SharedText: return self.shared_text(contents=name, kind=SharedTextKind.FILENAME) - def callable(self, name: str = "Foo.barMethod"): + def callable(self, name: str = "Foo.barMethod") -> SharedText: return self.shared_text(contents=name, kind=SharedTextKind.CALLABLE) - def message(self, name: str = "this is bad"): + def message(self, name: str = "this is bad") -> SharedText: return self.shared_text(contents=name, kind=SharedTextKind.MESSAGE) def instance( @@ -250,41 +262,43 @@ def instance( message: str = "this is bad", filename: str = "/r/some/filename.py", callable: str = "Foo.barMethod", - issue_id=None, - min_trace_length_to_sources=None, - min_trace_length_to_sinks=None, - purge_status=PurgeStatusForInstance.none, - run_id=None, - archive_if_new_issue=True, - ): + issue_id: Optional[DBID] = None, + min_trace_length_to_sources: Optional[int] = None, + min_trace_length_to_sinks: Optional[int] = None, + purge_status: PurgeStatusForInstance = PurgeStatusForInstance.none, + run_id: Optional[int] = None, + archive_if_new_issue: bool = True, + ) -> IssueInstance: issue_id = issue_id if issue_id is not None else DBID(1) - filename = self.filename(filename) - message = self.message(message) - callable = self.callable(callable) - result = IssueInstance.Record( - id=DBID(), - location=SourceLocation(6, 7, 8), - filename_id=filename.id, - message_id=message.id, - callable_id=callable.id, - run_id=run_id or self.run_id, - issue_id=issue_id, - min_trace_length_to_sources=min_trace_length_to_sources, - min_trace_length_to_sinks=min_trace_length_to_sinks, - purge_status=purge_status, - archive_if_new_issue=archive_if_new_issue, + filename_shared_text = self.filename(filename) + message_shared_text = self.message(message) + callable_shared_text = self.callable(callable) + result = cast( + IssueInstance, + IssueInstance.Record( + id=DBID(), + location=SourceLocation(6, 7, 8), + filename_id=filename_shared_text.id, + message_id=message_shared_text.id, + callable_id=callable_shared_text.id, + run_id=run_id or self.run_id, + issue_id=issue_id, + min_trace_length_to_sources=min_trace_length_to_sources, + min_trace_length_to_sinks=min_trace_length_to_sinks, + purge_status=purge_status, + archive_if_new_issue=archive_if_new_issue, + ), ) if self.graph: - # pyre-fixme[6]: For 1st param expected `IssueInstance` but got `Munch`. self.graph.add_issue_instance(result) else: self.saver.add(result) return result - def fix_info(self): + def fix_info(self) -> IssueInstanceFixInfo: result = IssueInstanceFixInfo.Record(id=DBID(), fix_info="fixthis") if self.graph: - self.graph.add_fix_info(result) + self.graph.add_issue_instance_fix_info(result.issue_instance, result) else: self.saver.add(result) return result @@ -294,7 +308,7 @@ def class_type_interval( class_name: str = "\\Foo", lower_bound: int = 0, upper_bound: int = 100, - run_id=None, + run_id: Optional[int] = None, ) -> ClassTypeInterval: interval = ClassTypeInterval.Record( id=DBID(), diff --git a/sapp/trace_graph.py b/sapp/trace_graph.py index 70de3fb..57b559c 100644 --- a/sapp/trace_graph.py +++ b/sapp/trace_graph.py @@ -260,9 +260,9 @@ def add_shared_text(self, shared_text: SharedText) -> None: # Allow look up of SharedTexts by name and kind (to optimize # get_shared_text which is called when parsing each issue instance) - self._shared_text_lookup[ - SharedTextKind.from_string_with_exception(shared_text.kind) - ][shared_text.contents] = shared_text.id.local_id + self._shared_text_lookup[shared_text.kind][shared_text.contents] = ( + shared_text.id.local_id + ) def get_or_add_shared_text(self, kind: SharedTextKind, name: str) -> SharedText: name = name[:SHARED_TEXT_LENGTH] @@ -523,9 +523,7 @@ def get_transform_normalized_caller_kind_id(self, leaf_kind: SharedText) -> int: ) if "@" in leaf_kind.contents or "!" in leaf_kind.contents: normal_name = self.get_transform_normalized_caller_kind(leaf_kind.contents) - normal_kind = self.get_or_add_shared_text( - SharedTextKind.from_string_with_exception(leaf_kind.kind), normal_name - ) + normal_kind = self.get_or_add_shared_text(leaf_kind.kind, normal_name) return normal_kind.id.local_id else: return leaf_kind.id.local_id @@ -544,9 +542,7 @@ def get_transformed_callee_kind_id(self, leaf_kind: SharedText) -> int: ) if "@" in leaf_kind.contents or "!" in leaf_kind.contents: rest = self.get_transformed_callee_kind(leaf_kind.contents) - remaining_kind = self.get_or_add_shared_text( - SharedTextKind.from_string_with_exception(leaf_kind.kind), rest - ) + remaining_kind = self.get_or_add_shared_text(leaf_kind.kind, rest) return remaining_kind.id.local_id else: return leaf_kind.id.local_id diff --git a/sapp/ui/interactive.py b/sapp/ui/interactive.py index f576f07..bba3276 100644 --- a/sapp/ui/interactive.py +++ b/sapp/ui/interactive.py @@ -15,7 +15,6 @@ from collections import defaultdict from typing import ( Any, - Callable, DefaultDict, Dict, Iterable, @@ -88,7 +87,7 @@ class LeafOrderBy(str, enum.Enum): number_issues = "number_issues" -ScopeVariables = Dict[str, Union[Callable[..., None], TraceKind]] +ScopeVariables = Dict[str, object] class Interactive: diff --git a/sapp/ui/tests/interactive_test.py b/sapp/ui/tests/interactive_test.py index c474a1f..8898072 100644 --- a/sapp/ui/tests/interactive_test.py +++ b/sapp/ui/tests/interactive_test.py @@ -4,13 +4,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import os import sys from datetime import datetime from io import StringIO -from typing import List +from typing import cast, List, Union from unittest import TestCase from unittest.mock import mock_open, patch @@ -64,7 +64,11 @@ def _clear_stdout(self) -> None: self.stdout = StringIO() sys.stdout = self.stdout - def _add_to_session(self, session, data) -> None: + def _add_to_session( + self, + session: Session, + data: Union[List[IssueInstanceSharedTextAssoc], List[Run], List[SharedText]], + ) -> None: if not isinstance(data, list): session.add(data) return @@ -206,6 +210,7 @@ def testListIssuesFilterCodes(self) -> None: self._list_issues_filter_setup() self.interactive.setup() + # pyre-ignore[6]: Intentional wrong type for testing. self.interactive.issues(codes="a string") stderr = self.stderr.getvalue().strip() self.assertIn("'codes' should be", stderr) @@ -227,6 +232,7 @@ def testListIssuesFilterCallables(self) -> None: self._list_issues_filter_setup() self.interactive.setup() + # pyre-ignore[6]: Intentional wrong type for testing. self.interactive.issues(callables=1234) stderr = self.stderr.getvalue().strip() self.assertIn("'callables' should be", stderr) @@ -248,6 +254,7 @@ def testListIssuesFilterFilenames(self) -> None: self._list_issues_filter_setup() self.interactive.setup() + # pyre-ignore[6]: Intentional wrong type for testing. self.interactive.issues(filenames=1234) stderr = self.stderr.getvalue().strip() self.assertIn("'filenames' should be", stderr) @@ -270,21 +277,25 @@ def testListIssuesFilterMinTraceLength(self) -> None: self.interactive.setup() + # pyre-ignore[6]: Intentional wrong type for testing. self.interactive.issues(exact_trace_length_to_sources="1") stderr = self.stderr.getvalue().strip() self.assertIn("'exact_trace_length_to_sources' should be", stderr) self._clear_stdout() + # pyre-ignore[6]: Intentional wrong type for testing. self.interactive.issues(exact_trace_length_to_sinks="1") stderr = self.stderr.getvalue().strip() self.assertIn("'exact_trace_length_to_sinks' should be", stderr) self._clear_stdout() + # pyre-ignore[6]: Intentional wrong type for testing. self.interactive.issues(max_trace_length_to_sources="1") stderr = self.stderr.getvalue().strip() self.assertIn("'max_trace_length_to_sources' should be", stderr) self._clear_stdout() + # pyre-ignore[6]: Intentional wrong type for testing. self.interactive.issues(max_trace_length_to_sinks="1") stderr = self.stderr.getvalue().strip() self.assertIn("'max_trace_length_to_sinks' should be", stderr) @@ -538,6 +549,7 @@ def testListIssuesFilterStatuses(self) -> None: self._list_issues_filter_setup() self.interactive.setup() + # pyre-ignore[6]: Intentional wrong type for testing. self.interactive.issues(statuses=1234) stderr = self.stderr.getvalue().strip() self.assertIn("'statuses' should be", stderr) @@ -588,7 +600,7 @@ def testSetRun(self) -> None: session.commit() self.interactive.setup() - self.interactive.run(1) + self.interactive.run(cast(DBID, 1)) self.interactive.issues() output = self.stdout.getvalue().strip() @@ -606,8 +618,8 @@ def testSetRunNonExistent(self) -> None: session.commit() self.interactive.setup() - self.interactive.run(2) - self.interactive.run(3) + self.interactive.run(cast(DBID, 2)) + self.interactive.run(cast(DBID, 3)) stderr = self.stderr.getvalue().strip() self.assertIn("Run 2 doesn't exist", stderr) @@ -654,14 +666,14 @@ def testSetIssue(self) -> None: self.interactive.setup() - self.interactive.issue(2) + self.interactive.issue(cast(DBID, 2)) self.assertEqual(int(self.interactive.current_issue_instance_id), 2) stdout = self.stdout.getvalue().strip() self.assertNotIn("Issue 1", stdout) self.assertIn("Issue 2", stdout) self.assertNotIn("Issue 3", stdout) - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) self.assertEqual(int(self.interactive.current_issue_instance_id), 1) stdout = self.stdout.getvalue().strip() self.assertIn("Issue 1", stdout) @@ -675,7 +687,7 @@ def testSetIssueNonExistent(self) -> None: session.commit() self.interactive.setup() - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) stderr = self.stderr.getvalue().strip() self.assertIn("Issue 1 doesn't exist", stderr) @@ -697,7 +709,7 @@ def testSetIssueUpdatesRun(self) -> None: self.interactive.setup() self.assertEqual(int(self.interactive._current_run_id), 2) - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) self.assertEqual(int(self.interactive._current_run_id), 1) def testGetSources(self) -> None: @@ -793,7 +805,7 @@ def testGetFeatures(self) -> None: self.assertIn("via:feature1", features) self.assertIn("via:feature2", features) - def _basic_trace_frames(self): + def _basic_trace_frames(self) -> List[TraceFrame]: return [ self.fakes.precondition( caller="call1", @@ -1083,7 +1095,7 @@ def testTraceFromIssue(self) -> None: stderr = self.stderr.getvalue().strip() self.assertIn("Use 'issue ID' or 'frame ID'", stderr) - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) self._clear_stdout() self.interactive.trace() self.assertEqual( @@ -1182,7 +1194,7 @@ def testTraceMissingFrames(self) -> None: session.commit() self.interactive.setup() - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) self.interactive.trace() stdout = self.stdout.getvalue().strip() self.assertIn("Missing trace frame: call2:param0", stdout) @@ -1238,7 +1250,7 @@ def testTraceCursorLocation(self) -> None: self.assertIsNone(self.interactive.callable()) - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) self.assertEqual(self.interactive.callable(), "Issue callable") self.assertEqual(self.interactive.current_trace_frame_index, 1) @@ -1295,7 +1307,7 @@ def testJumpToLocation(self) -> None: session.commit() self.interactive.setup() - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) self.assertEqual(self.interactive.current_trace_frame_index, 1) self.interactive.jump(1) @@ -1337,7 +1349,7 @@ def testTraceNoSinks(self) -> None: self.interactive.setup() self.interactive.sources = {"source1"} - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) self._clear_stdout() self.interactive.trace() self.assertEqual( @@ -1438,7 +1450,7 @@ def testTraceBranchNumber(self) -> None: self._set_up_branched_trace() self.interactive.setup() - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) self.assertEqual(self.interactive.sources, {"source1"}) self.assertEqual(self.interactive.sinks, {"sink1"}) @@ -1461,7 +1473,7 @@ def testShowBranches(self) -> None: self._set_up_branched_trace() self.interactive.setup() - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) # Parent at root self.interactive.prev_cursor_location() with patch("click.prompt", return_value=0): @@ -1523,7 +1535,7 @@ def testGetTraceFrameBranches(self) -> None: frames = self._set_up_branched_trace() self.interactive.setup() - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) # Parent at root self.interactive.prev_cursor_location() @@ -1547,7 +1559,7 @@ def testBranch(self) -> None: self._set_up_branched_trace() self.interactive.setup() - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) self.interactive.prev_cursor_location() # We are testing for the source location, which differs between branches @@ -1664,7 +1676,7 @@ def testBranchPrefixLengthChanges(self) -> None: session.commit() self.interactive.setup() - self.interactive.issue(1) + self.interactive.issue(cast(DBID, 1)) self._clear_stdout() self.interactive.prev_cursor_location() @@ -1878,7 +1890,7 @@ def testCreateIssueOutputStringNoSourcesNoSinks(self) -> None: # pyre-fixme[6]: For 15th param expected `FrozenSet[str]` but got # `List[str]`. sink_kinds=["sink1", "sink2"], - status=IssueStatus.UNCATEGORIZED, + status=IssueStatus.UNCATEGORIZED.name, detected_time=datetime.today(), # pyre-fixme[6]: For 18th param expected `Set[SimilarIssue]` but got # `Set[Tuple[int, str]]`. @@ -1945,7 +1957,7 @@ def testCreateIssueOutputStringNoFeatures(self) -> None: # pyre-fixme[6]: For 15th param expected `FrozenSet[str]` but got # `List[str]`. sink_kinds=["sink1"], - status=IssueStatus.UNCATEGORIZED, + status=cast(str, IssueStatus.UNCATEGORIZED), detected_time=datetime.today(), # pyre-fixme[6]: For 18th param expected `Set[SimilarIssue]` but got # `Set[Tuple[int, str]]`. @@ -2010,7 +2022,7 @@ def testCreateIssueOutputStringTraceLength(self) -> None: # pyre-fixme[6]: For 15th param expected `FrozenSet[str]` but got # `List[str]`. sink_kinds=["sink1", "sink2"], - status=IssueStatus.UNCATEGORIZED, + status=cast(str, IssueStatus.UNCATEGORIZED), detected_time=datetime.today(), # pyre-fixme[6]: For 18th param expected `Set[SimilarIssue]` but got # `Set[Tuple[int, str]]`. @@ -2059,7 +2071,7 @@ def testCreateIssueOutputStringTraceLength(self) -> None: # pyre-fixme[6]: For 15th param expected `FrozenSet[str]` but got # `List[str]`. sink_kinds=["sink1", "sink2"], - status=IssueStatus.UNCATEGORIZED, + status=cast(str, IssueStatus.UNCATEGORIZED), detected_time=datetime.today(), # pyre-fixme[6]: For 18th param expected `Set[SimilarIssue]` but got # `Set[Tuple[int, str]]`. @@ -2303,7 +2315,7 @@ def testListTracesFilterCallersCallees(self) -> None: def testListFramesWithLimit(self) -> None: frames = self._set_up_branched_trace() - self.interactive.run(1) + self.interactive.run(cast(DBID, 1)) self._clear_stdout() self.interactive.frames(limit=3) @@ -2567,9 +2579,9 @@ def testDetails(self) -> None: self.fakes.issue(callable="call3"), self.fakes.issue(callable="call2"), ] - (self.fakes.instance(issue_id=issues[0].id, callable="call2"),) - (self.fakes.instance(issue_id=issues[1].id, callable="call3"),) - (self.fakes.instance(issue_id=issues[2].id, callable="call2"),) + self.fakes.instance(issue_id=issues[0].id, callable="call2") + self.fakes.instance(issue_id=issues[1].id, callable="call3") + self.fakes.instance(issue_id=issues[2].id, callable="call2") self.fakes.save_all(self.db) with self.db.make_session(expire_on_commit=False) as session: @@ -2640,7 +2652,7 @@ def testListLeaves(self) -> None: self.assertIn("sink_detail_1", output) self.assertIn("sink_detail_2", output) - def mock_pager(self, output_string) -> None: + def mock_pager(self, output_string: str) -> None: # pyre-fixme[16]: `InteractiveTest` has no attribute `pager_calls`. self.pager_calls += 1 diff --git a/sapp/ui/tests/trace_test.py b/sapp/ui/tests/trace_test.py index 8c5faf6..c57e795 100644 --- a/sapp/ui/tests/trace_test.py +++ b/sapp/ui/tests/trace_test.py @@ -5,7 +5,7 @@ # pyre-strict -from typing import Any, List +from typing import Any, cast, List from unittest import TestCase from ...db import DB, DBType @@ -129,7 +129,11 @@ def testNextTraceFramesBackwards(self) -> None: session.commit() next_frames = trace_module.next_frames( - session, frames[1], {"sink1"}, set(), backwards=True + session, + cast(trace_module.TraceFrameQueryResult, frames[1]), + {"sink1"}, + set(), + backwards=True, ) self.assertEqual(len(next_frames), 1)