From 248117026a49e0365f3a16d8954be32ad063e29d Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 14:49:08 -0400 Subject: [PATCH 01/20] Add type hints to rapids_pre_commit_hooks.lint And change the API of Linter.line_for_pos() and refactor its test. --- src/rapids_pre_commit_hooks/lint.py | 80 ++++++++++++----------- test/rapids_pre_commit_hooks/test_lint.py | 64 +++++++++++------- 2 files changed, 81 insertions(+), 63 deletions(-) diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index 16b9f6e..f901782 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -18,6 +18,7 @@ import functools import re import warnings +from typing import Callable, Generator, Iterable, Optional from rich.console import Console from rich.markup import escape @@ -26,7 +27,7 @@ # Taken from Python docs # (https://docs.python.org/3.12/library/itertools.html#itertools.pairwise) # Replace with itertools.pairwise after dropping Python 3.9 support -def _pairwise(iterable): +def _pairwise(iterable: Iterable) -> Generator: # pairwise('ABCDEFG') → AB BC CD DE EF FG iterator = iter(iterable) a = next(iterator, None) @@ -35,6 +36,9 @@ def _pairwise(iterable): a = b +_PosType = tuple[int, int] + + class OverlappingReplacementsError(RuntimeError): pass @@ -44,29 +48,29 @@ class BinaryFileWarning(Warning): class Replacement: - def __init__(self, pos, newtext): - self.pos = pos - self.newtext = newtext + def __init__(self, pos: _PosType, newtext: str): + self.pos: _PosType = pos + self.newtext: str = newtext - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Replacement): return False return self.pos == other.pos and self.newtext == other.newtext - def __repr__(self): + def __repr__(self) -> str: return f"Replacement(pos={self.pos}, newtext={repr(self.newtext)})" class LintWarning: - def __init__(self, pos, msg): - self.pos = pos - self.msg = msg - self.replacements = [] + def __init__(self, pos: _PosType, msg: str): + self.pos: _PosType = pos + self.msg: str = msg + self.replacements: list[Replacement] = [] - def add_replacement(self, pos, newtext): + def add_replacement(self, pos: _PosType, newtext: str): self.replacements.append(Replacement(pos, newtext)) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, LintWarning): return False return ( @@ -75,31 +79,31 @@ def __eq__(self, other): and self.replacements == other.replacements ) - def __repr__(self): + def __repr__(self) -> str: return ( "LintWarning(" - + f"pos={self.pos}, " - + f"msg={self.msg}, " - + f"replacements={self.replacements})" + f"pos={self.pos}, " + f"msg={self.msg}, " + f"replacements={self.replacements})" ) class Linter: NEWLINE_RE = re.compile("[\r\n]") - def __init__(self, filename, content): - self.filename = filename - self.content = content - self.warnings = [] - self.console = Console(highlight=False) + def __init__(self, filename: str, content: str): + self.filename: str = filename + self.content: str = content + self.warnings: list[LintWarning] = [] + self.console: Console = Console(highlight=False) self._calculate_lines() - def add_warning(self, pos, msg): + def add_warning(self, pos: _PosType, msg: str): w = LintWarning(pos, msg) self.warnings.append(w) return w - def fix(self): + def fix(self) -> str: sorted_replacements = sorted( ( replacement @@ -123,7 +127,7 @@ def fix(self): replaced_content += self.content[cursor:] return replaced_content - def print_warnings(self, fix_applied=False): + def print_warnings(self, fix_applied: bool = False): sorted_warnings = sorted(self.warnings, key=lambda warning: warning.pos) for warning in sorted_warnings: @@ -172,7 +176,7 @@ def print_warnings(self, fix_applied=False): self.console.print("[bold]note:[/bold] suggested fix") self.console.print() - def print_highlighted_code(self, pos, replacement=None): + def print_highlighted_code(self, pos: _PosType, replacement: Optional[str] = None): line_index = self.line_for_pos(pos[0]) line_pos = self.lines[line_index] left = pos[0] @@ -200,11 +204,11 @@ def print_highlighted_code(self, pos, replacement=None): f"{escape(self.content[right:line_pos[1]])}[/green]" ) - def line_for_pos(self, index): + def line_for_pos(self, index: int) -> int: @functools.total_ordering class LineComparator: - def __init__(self, pos): - self.pos = pos + def __init__(self, pos: _PosType): + self.pos: _PosType = pos def __lt__(self, other): return self.pos[1] < other @@ -221,13 +225,13 @@ def __eq__(self, other): try: line_pos = self.lines[line_index] except IndexError: - return None - if line_pos[0] <= index <= line_pos[1]: - return line_index - return None + raise IndexError(f"Position {index} is not in the string") + if not (line_pos[0] <= index <= line_pos[1]): + raise IndexError(f"Position {index} is inside a line separator") + return line_index def _calculate_lines(self): - self.lines = [] + self.lines: list[_PosType] = [] line_begin = 0 line_end = 0 @@ -259,9 +263,9 @@ def _calculate_lines(self): class ExecutionContext(contextlib.AbstractContextManager): - def __init__(self, args): - self.args = args - self.checks = [] + def __init__(self, args: argparse.Namespace): + self.args: argparse.Namespace = args + self.checks: list[Callable[[], None]] = [] def add_check(self, check): self.checks.append(check) @@ -305,11 +309,11 @@ class LintMain: context_class = ExecutionContext def __init__(self): - self.argparser = argparse.ArgumentParser() + self.argparser: argparse.ArgumentParser = argparse.ArgumentParser() self.argparser.add_argument( "--fix", action="store_true", help="automatically fix warnings" ) self.argparser.add_argument("files", nargs="+", metavar="file") - def execute(self): + def execute(self) -> ExecutionContext: return self.context_class(self.argparser.parse_args()) diff --git a/test/rapids_pre_commit_hooks/test_lint.py b/test/rapids_pre_commit_hooks/test_lint.py index 47599c5..117fe7f 100644 --- a/test/rapids_pre_commit_hooks/test_lint.py +++ b/test/rapids_pre_commit_hooks/test_lint.py @@ -28,12 +28,13 @@ class TestLinter: + LONG_CONTENTS = ( + "line 1\nline 2\rline 3\r\nline 4\r\n\nline 6\r\n\r\nline 8\n\r\n" + "line 10\r\r\nline 12\r\n\rline 14\n\nline 16\r\rline 18\n\rline 20" + ) + def test_lines(self): - linter = Linter( - "test.txt", - "line 1\nline 2\rline 3\r\nline 4\r\n\nline 6\r\n\r\nline 8\n\r\n" - + "line 10\r\r\nline 12\r\n\rline 14\n\nline 16\r\rline 18\n\rline 20", - ) + linter = Linter("test.txt", self.LONG_CONTENTS) assert linter.lines == [ (0, 6), (7, 13), @@ -74,26 +75,39 @@ def test_lines(self): (0, 0), ] - def test_line_for_pos(self): - linter = Linter( - "test.txt", - "line 1\nline 2\rline 3\r\nline 4\r\n\nline 6\r\n\r\nline 8\n\r\n" - + "line 10\r\r\nline 12\r\n\rline 14\n\nline 16\r\rline 18\n\rline 20", - ) - assert linter.line_for_pos(0) == 0 - assert linter.line_for_pos(3) == 0 - assert linter.line_for_pos(6) == 0 - assert linter.line_for_pos(10) == 1 - assert linter.line_for_pos(21) is None - assert linter.line_for_pos(34) == 5 - assert linter.line_for_pos(97) == 19 - assert linter.line_for_pos(104) == 19 - assert linter.line_for_pos(200) is None - - linter = Linter("test.txt", "line 1") - assert linter.line_for_pos(0) == 0 - assert linter.line_for_pos(3) == 0 - assert linter.line_for_pos(6) == 0 + @pytest.mark.parametrize( + ["contents", "pos", "line", "raises"], + [ + (LONG_CONTENTS, 0, 0, contextlib.nullcontext()), + (LONG_CONTENTS, 3, 0, contextlib.nullcontext()), + (LONG_CONTENTS, 6, 0, contextlib.nullcontext()), + (LONG_CONTENTS, 10, 1, contextlib.nullcontext()), + ( + LONG_CONTENTS, + 21, + None, + pytest.raises( + IndexError, match="^Position 21 is inside a line separator$" + ), + ), + (LONG_CONTENTS, 34, 5, contextlib.nullcontext()), + (LONG_CONTENTS, 97, 19, contextlib.nullcontext()), + (LONG_CONTENTS, 104, 19, contextlib.nullcontext()), + ( + LONG_CONTENTS, + 200, + None, + pytest.raises(IndexError, match="^Position 200 is not in the string$"), + ), + ("line 1", 0, 0, contextlib.nullcontext()), + ("line 1", 3, 0, contextlib.nullcontext()), + ("line 1", 6, 0, contextlib.nullcontext()), + ], + ) + def test_line_for_pos(self, contents, pos, line, raises): + linter = Linter("test.txt", contents) + with raises: + assert linter.line_for_pos(pos) == line def test_fix(self): linter = Linter("test.txt", "Hello world!") From 6336dcadb7ac17f6c0ed01b0ae76db3ea9e9ffa1 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 15:08:19 -0400 Subject: [PATCH 02/20] Add type hints to rapids_pre_commit_hooks.alpha_spec --- src/rapids_pre_commit_hooks/alpha_spec.py | 100 ++++++++++++++++------ src/rapids_pre_commit_hooks/lint.py | 4 +- 2 files changed, 77 insertions(+), 27 deletions(-) diff --git a/src/rapids_pre_commit_hooks/alpha_spec.py b/src/rapids_pre_commit_hooks/alpha_spec.py index 78f056b..9dd1f5a 100644 --- a/src/rapids_pre_commit_hooks/alpha_spec.py +++ b/src/rapids_pre_commit_hooks/alpha_spec.py @@ -19,27 +19,27 @@ import yaml from packaging.requirements import InvalidRequirement, Requirement -from rapids_metadata.metadata import RAPIDSVersion +from rapids_metadata.metadata import RAPIDSMetadata, RAPIDSVersion from rapids_metadata.remote import fetch_latest -from .lint import LintMain +from .lint import Linter, LintMain -ALPHA_SPECIFIER = ">=0.0.0a0" +ALPHA_SPECIFIER: str = ">=0.0.0a0" -ALPHA_SPEC_OUTPUT_TYPES = { +ALPHA_SPEC_OUTPUT_TYPES: set[str] = { "pyproject", "requirements", } -CUDA_SUFFIX_REGEX = re.compile(r"^(?P.*)-cu[0-9]{2}$") +CUDA_SUFFIX_REGEX: re.Pattern = re.compile(r"^(?P.*)-cu[0-9]{2}$") @cache -def all_metadata(): +def all_metadata() -> RAPIDSMetadata: return fetch_latest() -def node_has_type(node, tag_type): +def node_has_type(node: yaml.Node, tag_type: str): return node.tag == f"tag:yaml.org,2002:{tag_type}" @@ -60,7 +60,9 @@ def strip_cuda_suffix(args: argparse.Namespace, name: str) -> str: return name -def check_and_mark_anchor(anchors, used_anchors, node): +def check_and_mark_anchor( + anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node +): for key, value in anchors.items(): if value == node: anchor = key @@ -74,16 +76,26 @@ def check_and_mark_anchor(anchors, used_anchors, node): return True, anchor -def check_package_spec(linter, args, anchors, used_anchors, node): +def check_package_spec( + linter: Linter, + args: argparse.Namespace, + anchors: dict[str, yaml.Node], + used_anchors: set[str], + node: yaml.Node, +): @total_ordering class SpecPriority: - def __init__(self, spec): - self.spec = spec + def __init__(self, spec: str): + self.spec: str = spec - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, SpecPriority): + return False return self.spec == other.spec - def __lt__(self, other): + def __lt__(self, other: object) -> bool: + if not isinstance(other, SpecPriority): + return False if self.spec == other.spec: return False if self.spec == ALPHA_SPECIFIER: @@ -92,10 +104,10 @@ def __lt__(self, other): return True return self.sort_str() < other.sort_str() - def sort_str(self): + def sort_str(self) -> str: return "".join(c for c in self.spec if c not in "<>=") - def create_specifier_string(specifiers): + def create_specifier_string(specifiers: set[str]) -> str: return ",".join(sorted(specifiers, key=SpecPriority)) if node_has_type(node, "str"): @@ -140,7 +152,13 @@ def create_specifier_string(specifiers): ) -def check_packages(linter, args, anchors, used_anchors, node): +def check_packages( + linter: Linter, + args: argparse.Namespace, + anchors: dict[str, yaml.Node], + used_anchors: set[str], + node: yaml.Node, +): if node_has_type(node, "seq"): descend, _ = check_and_mark_anchor(anchors, used_anchors, node) if descend: @@ -148,7 +166,13 @@ def check_packages(linter, args, anchors, used_anchors, node): check_package_spec(linter, args, anchors, used_anchors, package_spec) -def check_common(linter, args, anchors, used_anchors, node): +def check_common( + linter: Linter, + args: argparse.Namespace, + anchors: dict[str, yaml.Node], + used_anchors: set[str], + node: yaml.Node, +): if node_has_type(node, "seq"): for dependency_set in node.value: if node_has_type(dependency_set, "map"): @@ -162,7 +186,13 @@ def check_common(linter, args, anchors, used_anchors, node): ) -def check_matrices(linter, args, anchors, used_anchors, node): +def check_matrices( + linter: Linter, + args: argparse.Namespace, + anchors: dict[str, yaml.Node], + used_anchors: set[str], + node: yaml.Node, +): if node_has_type(node, "seq"): for item in node.value: if node_has_type(item, "map"): @@ -176,7 +206,13 @@ def check_matrices(linter, args, anchors, used_anchors, node): ) -def check_specific(linter, args, anchors, used_anchors, node): +def check_specific( + linter: Linter, + args: argparse.Namespace, + anchors: dict[str, yaml.Node], + used_anchors: set[str], + node: yaml.Node, +): if node_has_type(node, "seq"): for matrix_matcher in node.value: if node_has_type(matrix_matcher, "map"): @@ -190,7 +226,13 @@ def check_specific(linter, args, anchors, used_anchors, node): ) -def check_dependencies(linter, args, anchors, used_anchors, node): +def check_dependencies( + linter: Linter, + args: argparse.Namespace, + anchors: dict[str, yaml.Node], + used_anchors: set[str], + node: yaml.Node, +): if node_has_type(node, "map"): for _, dependencies_value in node.value: if node_has_type(dependencies_value, "map"): @@ -206,7 +248,13 @@ def check_dependencies(linter, args, anchors, used_anchors, node): ) -def check_root(linter, args, anchors, used_anchors, node): +def check_root( + linter: Linter, + args: argparse.Namespace, + anchors: dict[str, yaml.Node], + used_anchors: set[str], + node: yaml.Node, +): if node_has_type(node, "map"): for root_key, root_value in node.value: if node_has_type(root_key, "str") and root_key.value == "dependencies": @@ -221,27 +269,29 @@ class AnchorPreservingLoader(yaml.SafeLoader): def __init__(self, stream): super().__init__(stream) - self.document_anchors = [] + self.document_anchors: list[dict[str, yaml.Node]] = [] - def compose_document(self): + def compose_document(self) -> yaml.Node: # Drop the DOCUMENT-START event. self.get_event() # Compose the root node. - node = self.compose_node(None, None) + node = self.compose_node(None, None) # type: ignore # Drop the DOCUMENT-END event. self.get_event() self.document_anchors.append(self.anchors) self.anchors = {} + assert node is not None return node -def check_alpha_spec(linter, args): +def check_alpha_spec(linter: Linter, args: argparse.Namespace): loader = AnchorPreservingLoader(linter.content) try: root = loader.get_single_node() + assert root is not None finally: loader.dispose() check_root(linter, args, loader.document_anchors[0], set(), root) diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index f901782..c7a1cd8 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -265,9 +265,9 @@ def _calculate_lines(self): class ExecutionContext(contextlib.AbstractContextManager): def __init__(self, args: argparse.Namespace): self.args: argparse.Namespace = args - self.checks: list[Callable[[], None]] = [] + self.checks: list[Callable[[Linter, argparse.Namespace], None]] = [] - def add_check(self, check): + def add_check(self, check: Callable[[Linter, argparse.Namespace], None]): self.checks.append(check) def __exit__(self, exc_type, exc_value, traceback): From 90662734f5c5a1ad932bd75f64d40122b1e74900 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 15:24:25 -0400 Subject: [PATCH 03/20] Add type hints to rapids_pre_commit_hooks.copyright --- src/rapids_pre_commit_hooks/copyright.py | 50 +++++++++++++++--------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 68aba9b..ac16f82 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -12,22 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import datetime import functools import os import re import warnings +from typing import Callable, Optional, Union import git -from .lint import LintMain +from .lint import Linter, LintMain -COPYRIGHT_RE = re.compile( +COPYRIGHT_RE: re.Pattern = re.compile( r"Copyright *(?:\(c\))? *(?P(?P\d{4})(-(?P\d{4}))?),?" r" *NVIDIA C(?:ORPORATION|orporation)" ) -BRANCH_RE = re.compile(r"^branch-(?P[0-9]+)\.(?P[0-9]+)$") -COPYRIGHT_REPLACEMENT = "Copyright (c) {first_year}-{last_year}, NVIDIA CORPORATION" +BRANCH_RE: re.Pattern = re.compile(r"^branch-(?P[0-9]+)\.(?P[0-9]+)$") +COPYRIGHT_REPLACEMENT: str = ( + "Copyright (c) {first_year}-{last_year}, NVIDIA CORPORATION" +) class NoTargetBranchWarning(RuntimeWarning): @@ -38,14 +42,14 @@ class ConflictingFilesWarning(RuntimeWarning): pass -def match_copyright(content): +def match_copyright(content: str) -> list[re.Match]: return list(COPYRIGHT_RE.finditer(content)) -def strip_copyright(content, copyright_matches): +def strip_copyright(content: str, copyright_matches: list[re.Match]) -> list[str]: lines = [] - def append_stripped(start, item): + def append_stripped(start: int, item: re.Match): lines.append(content[start : item.start()]) return item.end() @@ -54,7 +58,7 @@ def append_stripped(start, item): return lines -def apply_copyright_revert(linter, old_match, new_match): +def apply_copyright_revert(linter: Linter, old_match: re.Match, new_match: re.Match): if old_match.group("years") == new_match.group("years"): warning_pos = new_match.span() else: @@ -65,7 +69,7 @@ def apply_copyright_revert(linter, old_match, new_match): ).add_replacement(new_match.span(), old_match.group()) -def apply_copyright_update(linter, match, year): +def apply_copyright_update(linter: Linter, match: re.Match, year: int): linter.add_warning(match.span("years"), "copyright is out of date").add_replacement( match.span(), COPYRIGHT_REPLACEMENT.format( @@ -75,7 +79,7 @@ def apply_copyright_update(linter, match, year): ) -def apply_copyright_check(linter, old_content): +def apply_copyright_check(linter: Linter, old_content: str): if linter.content != old_content: current_year = datetime.datetime.now().year new_copyright_matches = match_copyright(linter.content) @@ -102,7 +106,7 @@ def apply_copyright_check(linter, old_content): linter.add_warning((0, 0), "no copyright notice found") -def get_target_branch(repo, args): +def get_target_branch(repo: git.Repo, args: argparse.Namespace) -> Optional[str]: """Determine which branch is the "target" branch. The target branch is determined in the following order: @@ -168,7 +172,9 @@ def get_target_branch(repo, args): return None -def get_target_branch_upstream_commit(repo, args): +def get_target_branch_upstream_commit( + repo: git.Repo, args: argparse.Namespace +) -> Optional[git.Commit]: # If no target branch can be determined, use HEAD if it exists target_branch_name = get_target_branch(repo, args) if target_branch_name is None: @@ -194,7 +200,7 @@ def get_target_branch_upstream_commit(repo, args): key=lambda commit: commit.committed_datetime, ) - def try_get_ref(remote): + def try_get_ref(remote: git.Remote) -> Optional[git.Reference]: try: return remote.refs[target_branch_name] except IndexError: @@ -222,7 +228,9 @@ def try_get_ref(remote): return None -def get_changed_files(args): +def get_changed_files( + args: argparse.Namespace, +) -> dict[Union[str, os.PathLike[str]], Optional[git.Blob]]: try: repo = git.Repo() except git.InvalidGitRepositoryError: @@ -232,7 +240,9 @@ def get_changed_files(args): for filename in filenames } - changed_files = {f: None for f in repo.untracked_files} + changed_files: dict[Union[str, os.PathLike[str]], Optional[git.Blob]] = { + f: None for f in repo.untracked_files + } target_branch_upstream_commit = get_target_branch_upstream_commit(repo, args) if target_branch_upstream_commit is None: changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()}) @@ -256,14 +266,14 @@ def get_changed_files(args): return changed_files -def normalize_git_filename(filename): +def normalize_git_filename(filename: Union[str, os.PathLike[str]]): relpath = os.path.relpath(filename) if re.search(r"^\.\.(/|$)", relpath): return None return relpath -def find_blob(tree, filename): +def find_blob(tree: git.Tree, filename: Union[str, os.PathLike[str]]): d1, d2 = os.path.split(filename) split = [d2] while d1: @@ -283,10 +293,12 @@ def find_blob(tree, filename): return None -def check_copyright(args): +def check_copyright( + args: argparse.Namespace, +) -> Callable[[Linter, argparse.Namespace], None]: changed_files = get_changed_files(args) - def the_check(linter, args): + def the_check(linter: Linter, args: argparse.Namespace): if not (git_filename := normalize_git_filename(linter.filename)): warnings.warn( f'File "{linter.filename}" is outside of current directory. Not ' From c471b96c8e387e64c008a8df26413c7c4ea95597 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 15:36:16 -0400 Subject: [PATCH 04/20] Add type hints to rapids_pre_commit_hooks.pyproject_license --- .../pyproject_license.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/rapids_pre_commit_hooks/pyproject_license.py b/src/rapids_pre_commit_hooks/pyproject_license.py index ab119a9..849605e 100644 --- a/src/rapids_pre_commit_hooks/pyproject_license.py +++ b/src/rapids_pre_commit_hooks/pyproject_license.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import copy import uuid import tomlkit import tomlkit.exceptions -from .lint import LintMain +from .lint import Linter, LintMain RAPIDS_LICENSE = "Apache 2.0" ACCEPTABLE_LICENSES = { @@ -27,7 +28,12 @@ } -def find_value_location(document, key, append): +_LocType = tuple[int, int] + + +def find_value_location( + document: tomlkit.TOMLDocument, key: tuple[str, ...], append: bool +) -> _LocType: copied_document = copy.deepcopy(document) placeholder = uuid.uuid4() placeholder_toml = tomlkit.string(str(placeholder)) @@ -38,7 +44,7 @@ def find_value_location(document, key, append): # look for that in the new document. node = copied_document while len(key) > (0 if append else 1): - node = node[key[0]] + node = node[key[0]] # type: ignore key = key[1:] if append: node.add(str(placeholder), placeholder_toml) @@ -54,13 +60,13 @@ def find_value_location(document, key, append): return begin_loc, end_loc -def check_pyproject_license(linter, args): +def check_pyproject_license(linter: Linter, args: argparse.Namespace): document = tomlkit.loads(linter.content) try: add_project_table = True project_table = document["project"] - add_project_table = project_table.is_super_table() - license_value = project_table["license"]["text"] + add_project_table = project_table.is_super_table() # type: ignore + license_value = project_table["license"]["text"] # type: ignore except tomlkit.exceptions.NonExistentKey: if add_project_table: loc = (len(linter.content), len(linter.content)) From f1447e53efe0231dfce1d2ee187f8f3daf21c841 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 15:38:27 -0400 Subject: [PATCH 05/20] Add type hints to rapids_pre_commit_hooks.shell --- src/rapids_pre_commit_hooks/shell/__init__.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/rapids_pre_commit_hooks/shell/__init__.py b/src/rapids_pre_commit_hooks/shell/__init__.py index 15b407d..366240e 100644 --- a/src/rapids_pre_commit_hooks/shell/__init__.py +++ b/src/rapids_pre_commit_hooks/shell/__init__.py @@ -12,30 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse + import bashlex -from ..lint import ExecutionContext, LintMain +from ..lint import ExecutionContext, Linter, LintMain + +_PosType = tuple[int, int] class LintVisitor(bashlex.ast.nodevisitor): - def __init__(self, linter, args): - self.linter = linter - self.args = args + def __init__(self, linter: Linter, args: argparse.Namespace): + self.linter: Linter = linter + self.args: argparse.Namespace = args - def add_warning(self, pos, msg): + def add_warning(self, pos: _PosType, msg: str): return self.linter.add_warning(pos, msg) class ShellExecutionContext(ExecutionContext): - def __init__(self, args): + def __init__(self, args: argparse.Namespace): super().__init__(args) - self.visitors = [] + self.visitors: list[type] = [] self.add_check(self.check_shell) - def add_visitor_class(self, cls): + def add_visitor_class(self, cls: type): self.visitors.append(cls) - def check_shell(self, linter, args): + def check_shell(self, linter: Linter, args: argparse.Namespace): parts = bashlex.parse(linter.content) for cls in self.visitors: From dff9f469a673f440b57bfbdd8d71f51ed7f44c41 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 15:58:25 -0400 Subject: [PATCH 06/20] Add type hints to test/rapids_pre_commit_hooks/test_alpha_spec.py --- .../test_alpha_spec.py | 74 +++++++++++-------- 1 file changed, 42 insertions(+), 32 deletions(-) diff --git a/test/rapids_pre_commit_hooks/test_alpha_spec.py b/test/rapids_pre_commit_hooks/test_alpha_spec.py index b958c8d..715bd7e 100644 --- a/test/rapids_pre_commit_hooks/test_alpha_spec.py +++ b/test/rapids_pre_commit_hooks/test_alpha_spec.py @@ -16,6 +16,7 @@ import os.path from itertools import chain from textwrap import dedent +from typing import Iterator, Optional from unittest.mock import MagicMock, Mock, call, patch import pytest @@ -31,7 +32,7 @@ @contextlib.contextmanager -def set_cwd(cwd): +def set_cwd(cwd: os.PathLike[str]) -> Iterator: old_cwd = os.getcwd() os.chdir(cwd) try: @@ -41,17 +42,23 @@ def set_cwd(cwd): @pytest.mark.parametrize( - ["version_file", "version_arg", "expected_version"], + ["version_file", "version_arg", "expected_version", "raises"], [ - ("24.06", None, "24.06"), - ("24.06", "24.08", "24.08"), - ("24.08", "24.06", "24.06"), - (None, "24.06", "24.06"), - (None, "24.10", KeyError), - (None, None, FileNotFoundError), + ("24.06", None, "24.06", contextlib.nullcontext()), + ("24.06", "24.08", "24.08", contextlib.nullcontext()), + ("24.08", "24.06", "24.06", contextlib.nullcontext()), + (None, "24.06", "24.06", contextlib.nullcontext()), + (None, "24.10", None, pytest.raises(KeyError)), + (None, None, None, pytest.raises(FileNotFoundError)), ], ) -def test_get_rapids_version(tmp_path, version_file, version_arg, expected_version): +def test_get_rapids_version( + tmp_path: os.PathLike, + version_file: Optional[str], + version_arg: Optional[str], + expected_version: Optional[str], + raises: contextlib.AbstractContextManager, +): MOCK_METADATA = RAPIDSMetadata( versions={ "24.06": RAPIDSVersion( @@ -74,16 +81,10 @@ def test_get_rapids_version(tmp_path, version_file, version_arg, expected_versio with open("VERSION", "w") as f: f.write(f"{version_file}\n") args = Mock(rapids_version=version_arg) - if isinstance(expected_version, type) and issubclass( - expected_version, BaseException - ): - with pytest.raises(expected_version): - alpha_spec.get_rapids_version(args) - else: - assert ( - alpha_spec.get_rapids_version(args) - == MOCK_METADATA.versions[expected_version] - ) + with raises: + version = alpha_spec.get_rapids_version(args) + if expected_version: + assert version == MOCK_METADATA.versions[expected_version] def test_anchor_preserving_loader(): @@ -127,8 +128,8 @@ def test_anchor_preserving_loader(): "rapids_pre_commit_hooks.alpha_spec.get_rapids_version", Mock(return_value=latest_metadata), ) -def test_strip_cuda_suffix(name, stripped_name): - assert alpha_spec.strip_cuda_suffix(None, name) == stripped_name +def test_strip_cuda_suffix(name: str, stripped_name: str): + assert alpha_spec.strip_cuda_suffix(Mock(), name) == stripped_name @pytest.mark.parametrize( @@ -172,10 +173,14 @@ def test_strip_cuda_suffix(name, stripped_name): ], ) def test_check_and_mark_anchor( - used_anchors_before, node_index, descend, anchor, used_anchors_after + used_anchors_before: set[str], + node_index: int, + descend: bool, + anchor: Optional[str], + used_anchors_after: set[str], ): NODES = [Mock() for _ in range(3)] - ANCHORS = { + ANCHORS: dict[str, yaml.Node] = { "anchor1": NODES[0], "anchor2": NODES[1], } @@ -246,12 +251,13 @@ def test_check_and_mark_anchor( "rapids_pre_commit_hooks.alpha_spec.get_rapids_version", Mock(return_value=latest_metadata), ) -def test_check_package_spec(package, content, mode, replacement): +def test_check_package_spec(package: str, content: str, mode: str, replacement: str): args = Mock(mode=mode) linter = lint.Linter("dependencies.yaml", content) loader = alpha_spec.AnchorPreservingLoader(content) try: composed = loader.get_single_node() + assert composed is not None finally: loader.dispose() alpha_spec.check_package_spec( @@ -349,7 +355,7 @@ def test_check_package_spec_anchor(): ), ], ) -def test_check_packages(content, indices, use_anchor): +def test_check_packages(content: str, indices: list[int], use_anchor: bool): with patch( "rapids_pre_commit_hooks.alpha_spec.check_package_spec", Mock() ) as mock_check_package_spec: @@ -357,7 +363,7 @@ def test_check_packages(content, indices, use_anchor): linter = lint.Linter("dependencies.yaml", content) composed = yaml.compose(content) anchors = {"anchor": composed} - used_anchors = set() + used_anchors: set[str] = set() alpha_spec.check_packages(linter, args, anchors, used_anchors, composed) assert used_anchors == ({"anchor"} if use_anchor else set()) alpha_spec.check_packages(linter, args, anchors, used_anchors, composed) @@ -387,7 +393,7 @@ def test_check_packages(content, indices, use_anchor): ), ], ) -def test_check_common(content, indices): +def test_check_common(content: str, indices: list[tuple[int, int]]): with patch( "rapids_pre_commit_hooks.alpha_spec.check_packages", Mock() ) as mock_check_packages: @@ -422,7 +428,7 @@ def test_check_common(content, indices): ), ], ) -def test_check_matrices(content, indices): +def test_check_matrices(content: str, indices: list[tuple[int, int]]): with patch( "rapids_pre_commit_hooks.alpha_spec.check_packages", Mock() ) as mock_check_packages: @@ -468,7 +474,7 @@ def test_check_matrices(content, indices): ), ], ) -def test_check_specific(content, indices): +def test_check_specific(content: str, indices: list[tuple[int, int]]): with patch( "rapids_pre_commit_hooks.alpha_spec.check_matrices", Mock() ) as mock_check_matrices: @@ -521,7 +527,11 @@ def test_check_specific(content, indices): ), ], ) -def test_check_dependencies(content, common_indices, specific_indices): +def test_check_dependencies( + content: str, + common_indices: list[tuple[int, int]], + specific_indices: list[tuple[int, int]], +): with patch( "rapids_pre_commit_hooks.alpha_spec.check_common", Mock() ) as mock_check_common, patch( @@ -558,7 +568,7 @@ def test_check_dependencies(content, common_indices, specific_indices): ), ], ) -def test_check_root(content, indices): +def test_check_root(content: str, indices: list[int]): with patch( "rapids_pre_commit_hooks.alpha_spec.check_dependencies", Mock() ) as mock_check_dependencies: @@ -593,7 +603,7 @@ def test_check_alpha_spec(): ) -def test_check_alpha_spec_integration(tmp_path): +def test_check_alpha_spec_integration(tmp_path: os.PathLike[str]): CONTENT = dedent( """\ dependencies: From 04e82822db30efe65e3290b6f1d2ae5857f1b20c Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 16:35:38 -0400 Subject: [PATCH 07/20] Add type hints to test/rapids_pre_commit_hooks/test_copyright.py --- .../rapids_pre_commit_hooks/test_copyright.py | 253 ++++++++++-------- 1 file changed, 141 insertions(+), 112 deletions(-) diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 50e118b..cbb0ed1 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import datetime import os.path import tempfile +from io import BufferedReader +from textwrap import dedent +from typing import Any, TextIO, Union from unittest.mock import Mock, patch import git @@ -26,11 +30,13 @@ def test_match_copyright(): - CONTENT = r""" -Copyright (c) 2024 NVIDIA CORPORATION -Copyright (c) 2021-2024 NVIDIA CORPORATION -# Copyright 2021, NVIDIA Corporation and affiliates -""" + CONTENT = dedent( + r""" + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2021-2024 NVIDIA CORPORATION + # Copyright 2021, NVIDIA Corporation and affiliates + """ + ) re_matches = copyright.match_copyright(CONTENT) matches = [ @@ -65,15 +71,17 @@ def test_match_copyright(): def test_strip_copyright(): - CONTENT = r""" -This is a line before the first copyright statement -Copyright (c) 2024 NVIDIA CORPORATION -This is a line between the first two copyright statements -Copyright (c) 2021-2024 NVIDIA CORPORATION -This is a line between the next two copyright statements -# Copyright 2021, NVIDIA Corporation and affiliates -This is a line after the last copyright statement -""" + CONTENT = dedent( + r""" + This is a line before the first copyright statement + Copyright (c) 2024 NVIDIA CORPORATION + This is a line between the first two copyright statements + Copyright (c) 2021-2024 NVIDIA CORPORATION + This is a line between the next two copyright statements + # Copyright 2021, NVIDIA Corporation and affiliates + This is a line after the last copyright statement + """ + ) matches = copyright.match_copyright(CONTENT) stripped = copyright.strip_copyright(CONTENT, matches) assert stripped == [ @@ -89,7 +97,7 @@ def test_strip_copyright(): @freeze_time("2024-01-18") def test_apply_copyright_check(): - def run_apply_copyright_check(old_content, new_content): + def run_apply_copyright_check(old_content: str, new_content: str): linter = Linter("file.txt", new_content) copyright.apply_copyright_check(linter, old_content) return linter @@ -103,23 +111,27 @@ def run_apply_copyright_check(old_content, new_content): linter = run_apply_copyright_check("No copyright notice", "No copyright notice") assert linter.warnings == [] - OLD_CONTENT = r""" -Copyright (c) 2021-2023 NVIDIA CORPORATION -Copyright (c) 2023 NVIDIA CORPORATION -Copyright (c) 2024 NVIDIA CORPORATION -Copyright (c) 2025 NVIDIA CORPORATION -This file has not been changed -""" + OLD_CONTENT = dedent( + r""" + Copyright (c) 2021-2023 NVIDIA CORPORATION + Copyright (c) 2023 NVIDIA CORPORATION + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2025 NVIDIA CORPORATION + This file has not been changed + """ + ) linter = run_apply_copyright_check(OLD_CONTENT, OLD_CONTENT) assert linter.warnings == [] - NEW_CONTENT = r""" -Copyright (c) 2021-2023 NVIDIA CORPORATION -Copyright (c) 2023 NVIDIA CORPORATION -Copyright (c) 2024 NVIDIA CORPORATION -Copyright (c) 2025 NVIDIA CORPORATION -This file has been changed -""" + NEW_CONTENT = dedent( + r""" + Copyright (c) 2021-2023 NVIDIA CORPORATION + Copyright (c) 2023 NVIDIA CORPORATION + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2025 NVIDIA CORPORATION + This file has been changed + """ + ) expected_linter = Linter("file.txt", NEW_CONTENT) expected_linter.add_warning((15, 24), "copyright is out of date").add_replacement( (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION" @@ -142,13 +154,15 @@ def run_apply_copyright_check(old_content, new_content): linter = run_apply_copyright_check(None, NEW_CONTENT) assert linter.warnings == expected_linter.warnings - NEW_CONTENT = r""" -Copyright (c) 2021-2024 NVIDIA CORPORATION -Copyright (c) 2023 NVIDIA CORPORATION -Copyright (c) 2024 NVIDIA CORPORATION -Copyright (c) 2025 NVIDIA Corporation -This file has not been changed -""" + NEW_CONTENT = dedent( + r""" + Copyright (c) 2021-2024 NVIDIA CORPORATION + Copyright (c) 2023 NVIDIA CORPORATION + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2025 NVIDIA Corporation + This file has not been changed + """ + ) expected_linter = Linter("file.txt", NEW_CONTENT) expected_linter.add_warning( (15, 24), "copyright is not out of date and should not be updated" @@ -162,16 +176,17 @@ def run_apply_copyright_check(old_content, new_content): @pytest.fixture -def git_repo(): - with tempfile.TemporaryDirectory() as d: - repo = git.Repo.init(d) - with repo.config_writer() as w: - w.set_value("user", "name", "RAPIDS Test Fixtures") - w.set_value("user", "email", "testfixtures@rapids.ai") - yield repo +def git_repo(tmp_path: os.PathLike[str]) -> git.Repo: + repo = git.Repo.init(tmp_path) + with repo.config_writer() as w: + w.set_value("user", "name", "RAPIDS Test Fixtures") + w.set_value("user", "email", "testfixtures@rapids.ai") + return repo + +def test_get_target_branch(git_repo: git.Repo): + assert git_repo.working_tree_dir is not None -def test_get_target_branch(git_repo): with patch.dict("os.environ", {}, clear=True): args = Mock(main_branch=None, target_branch=None) @@ -254,15 +269,16 @@ def test_get_target_branch(git_repo): assert copyright.get_target_branch(git_repo, args) == "master" -def test_get_target_branch_upstream_commit(git_repo): - def fn(repo, filename): +def test_get_target_branch_upstream_commit(git_repo: git.Repo): + def fn(repo: git.Repo, filename: str) -> str: + assert repo.working_tree_dir is not None return os.path.join(repo.working_tree_dir, filename) - def write_file(repo, filename, contents): + def write_file(repo: git.Repo, filename: str, contents: str): with open(fn(repo, filename), "w") as f: f.write(contents) - def mock_target_branch(branch): + def mock_target_branch(branch: Any): return patch( "rapids_pre_commit_hooks.copyright.get_target_branch", Mock(return_value=branch), @@ -300,7 +316,7 @@ def mock_target_branch(branch): remote_1_branch_1 = remote_repo_1.create_head( "branch-1-renamed", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_1 + remote_repo_1.head.reference = remote_1_branch_1 # type: ignore remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file1.txt", "File 1 modified") remote_repo_1.index.add(["file1.txt"]) @@ -312,7 +328,7 @@ def mock_target_branch(branch): remote_1_branch_2 = remote_repo_1.create_head( "branch-2", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_2 + remote_repo_1.head.reference = remote_1_branch_2 # type: ignore remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file2.txt", "File 2 modified") remote_repo_1.index.add(["file2.txt"]) @@ -321,7 +337,7 @@ def mock_target_branch(branch): remote_1_branch_3 = remote_repo_1.create_head( "branch-3", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_3 + remote_repo_1.head.reference = remote_1_branch_3 # type: ignore remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file3.txt", "File 3 modified") remote_repo_1.index.add(["file3.txt"]) @@ -333,7 +349,7 @@ def mock_target_branch(branch): remote_1_branch_4 = remote_repo_1.create_head( "branch-4", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_4 + remote_repo_1.head.reference = remote_1_branch_4 # type: ignore remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file4.txt", "File 4 modified") remote_repo_1.index.add(["file4.txt"]) @@ -345,7 +361,7 @@ def mock_target_branch(branch): remote_1_branch_7 = remote_repo_1.create_head( "branch-7", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_7 + remote_repo_1.head.reference = remote_1_branch_7 # type: ignore remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file7.txt", "File 7 modified") remote_repo_1.index.add(["file7.txt"]) @@ -361,7 +377,7 @@ def mock_target_branch(branch): remote_2_branch_3 = remote_repo_2.create_head( "branch-3", remote_2_master.commit ) - remote_repo_2.head.reference = remote_2_branch_3 + remote_repo_2.head.reference = remote_2_branch_3 # type: ignore remote_repo_2.head.reset(index=True, working_tree=True) write_file(remote_repo_2, "file3.txt", "File 3 modified") remote_repo_2.index.add(["file3.txt"]) @@ -373,7 +389,7 @@ def mock_target_branch(branch): remote_2_branch_4 = remote_repo_2.create_head( "branch-4", remote_2_master.commit ) - remote_repo_2.head.reference = remote_2_branch_4 + remote_repo_2.head.reference = remote_2_branch_4 # type: ignore remote_repo_2.head.reset(index=True, working_tree=True) write_file(remote_repo_2, "file4.txt", "File 4 modified") remote_repo_2.index.add(["file4.txt"]) @@ -385,17 +401,17 @@ def mock_target_branch(branch): remote_2_branch_5 = remote_repo_2.create_head( "branch-5", remote_2_master.commit ) - remote_repo_2.head.reference = remote_2_branch_5 + remote_repo_2.head.reference = remote_2_branch_5 # type: ignore remote_repo_2.head.reset(index=True, working_tree=True) write_file(remote_repo_2, "file5.txt", "File 5 modified") remote_repo_2.index.add(["file5.txt"]) remote_repo_2.index.commit("Update file5.txt") with mock_target_branch(None): - assert copyright.get_target_branch_upstream_commit(git_repo, None) is None + assert copyright.get_target_branch_upstream_commit(git_repo, Mock()) is None with mock_target_branch("branch-1"): - assert copyright.get_target_branch_upstream_commit(git_repo, None) is None + assert copyright.get_target_branch_upstream_commit(git_repo, Mock()) is None remote_1 = git_repo.create_remote("unconventional/remote/name/1", remote_dir_1) remote_1.fetch([ @@ -415,7 +431,7 @@ def mock_target_branch(branch): with branch_1.config_writer() as w: w.set_value("remote", "unconventional/remote/name/1") w.set_value("merge", "branch-1-renamed") - git_repo.head.reference = branch_1 + git_repo.head.reference = branch_1 # type: ignore git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove("file1.txt", working_tree=True) git_repo.index.commit( @@ -424,7 +440,7 @@ def mock_target_branch(branch): ) branch_6 = git_repo.create_head("branch-6", remote_1.refs["master"]) - git_repo.head.reference = branch_6 + git_repo.head.reference = branch_6 # type: ignore git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove(["file6.txt"], working_tree=True) git_repo.index.commit("Remove file6.txt") @@ -433,7 +449,7 @@ def mock_target_branch(branch): with branch_7.config_writer() as w: w.set_value("remote", "unconventional/remote/name/1") w.set_value("merge", "branch-7") - git_repo.head.reference = branch_7 + git_repo.head.reference = branch_7 # type: ignore git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove(["file7.txt"], working_tree=True) git_repo.index.commit( @@ -441,66 +457,70 @@ def mock_target_branch(branch): commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc), ) - git_repo.head.reference = main + git_repo.head.reference = main # type: ignore git_repo.head.reset(index=True, working_tree=True) with mock_target_branch("branch-1"): assert ( - copyright.get_target_branch_upstream_commit(git_repo, None) + copyright.get_target_branch_upstream_commit(git_repo, Mock()) == remote_1.refs["branch-1-renamed"].commit ) with mock_target_branch("branch-2"): assert ( - copyright.get_target_branch_upstream_commit(git_repo, None) + copyright.get_target_branch_upstream_commit(git_repo, Mock()) == remote_1.refs["branch-2"].commit ) with mock_target_branch("branch-3"): assert ( - copyright.get_target_branch_upstream_commit(git_repo, None) + copyright.get_target_branch_upstream_commit(git_repo, Mock()) == remote_1.refs["branch-3"].commit ) with mock_target_branch("branch-4"): assert ( - copyright.get_target_branch_upstream_commit(git_repo, None) + copyright.get_target_branch_upstream_commit(git_repo, Mock()) == remote_2.refs["branch-4"].commit ) with mock_target_branch("branch-5"): assert ( - copyright.get_target_branch_upstream_commit(git_repo, None) + copyright.get_target_branch_upstream_commit(git_repo, Mock()) == remote_2.refs["branch-5"].commit ) with mock_target_branch("branch-6"): assert ( - copyright.get_target_branch_upstream_commit(git_repo, None) + copyright.get_target_branch_upstream_commit(git_repo, Mock()) == branch_6.commit ) with mock_target_branch("branch-7"): assert ( - copyright.get_target_branch_upstream_commit(git_repo, None) + copyright.get_target_branch_upstream_commit(git_repo, Mock()) == branch_7.commit ) with mock_target_branch("nonexistent-branch"): assert ( - copyright.get_target_branch_upstream_commit(git_repo, None) + copyright.get_target_branch_upstream_commit(git_repo, Mock()) == main.commit ) with mock_target_branch(None): assert ( - copyright.get_target_branch_upstream_commit(git_repo, None) + copyright.get_target_branch_upstream_commit(git_repo, Mock()) == main.commit ) -def test_get_changed_files(git_repo): - def mock_os_walk(top): +def test_get_changed_files(git_repo: git.Repo): + f: Union[BufferedReader, TextIO] + + assert git_repo.working_tree_dir is not None + + def mock_os_walk(top: Union[str, os.PathLike[str]]): return patch( "os.walk", Mock( @@ -526,19 +546,20 @@ def mock_os_walk(top): os.mkdir(os.path.join(non_git_dir, "subdir1/subdir2")) with open(os.path.join(non_git_dir, "subdir1", "subdir2", "sub.txt"), "w") as f: f.write("Subdir file\n") - assert copyright.get_changed_files(None) == { + assert copyright.get_changed_files(Mock()) == { "top.txt": None, "subdir1/subdir2/sub.txt": None, } - def fn(filename): + def fn(filename: str) -> str: + assert git_repo.working_tree_dir is not None return os.path.join(git_repo.working_tree_dir, filename) - def write_file(filename, contents): + def write_file(filename: str, contents: str): with open(fn(filename), "w") as f: f.write(contents) - def file_contents(verbed): + def file_contents(verbed: str) -> str: return f"This file will be {verbed}\n" * 100 write_file("untouched.txt", file_contents("untouched")) @@ -571,7 +592,7 @@ def file_contents(verbed): "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", Mock(return_value=None), ): - assert copyright.get_changed_files(None) == { + assert copyright.get_changed_files(Mock()) == { "untouched.txt": None, "copied.txt": None, "modified_and_copied.txt": None, @@ -591,7 +612,7 @@ def file_contents(verbed): git_repo.index.commit("Remove modified.txt") pr_branch = git_repo.create_head("pr", "HEAD~") - git_repo.head.reference = pr_branch + git_repo.head.reference = pr_branch # type: ignore git_repo.head.reset(index=True, working_tree=True) write_file("copied_2.txt", file_contents("copied")) @@ -639,7 +660,7 @@ def file_contents(verbed): target_branch = git_repo.heads["master"] merge_base = git_repo.merge_base(target_branch, "HEAD")[0] old_files = { - blob.path: blob + blob.path: blob # type: ignore for blob in merge_base.tree.traverse(lambda b, _: isinstance(b, git.Blob)) } @@ -665,7 +686,7 @@ def file_contents(verbed): "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", Mock(return_value=target_branch.commit), ): - changed_files = copyright.get_changed_files(None) + changed_files = copyright.get_changed_files(Mock()) assert { path: old_blob.path if old_blob else None for path, old_blob in changed_files.items() @@ -675,24 +696,25 @@ def file_contents(verbed): if old: with open(fn(new), "rb") as f: new_contents = f.read() - old_contents = old_files[old].data_stream.read() + old_contents = old_files[old].data_stream.read() # type: ignore assert new_contents != old_contents - assert changed_files[new].data_stream.read() == old_contents + assert changed_files[new].data_stream.read() == old_contents # type: ignore for new, old in superfluous.items(): if old: with open(fn(new), "rb") as f: new_contents = f.read() - old_contents = old_files[old].data_stream.read() + old_contents = old_files[old].data_stream.read() # type: ignore assert new_contents == old_contents - assert changed_files[new].data_stream.read() == old_contents + assert changed_files[new].data_stream.read() == old_contents # type: ignore -def test_get_changed_files_multiple_merge_bases(git_repo): - def fn(filename): +def test_get_changed_files_multiple_merge_bases(git_repo: git.Repo): + def fn(filename: str) -> str: + assert git_repo.working_tree_dir is not None return os.path.join(git_repo.working_tree_dir, filename) - def write_file(filename, contents): + def write_file(filename: str, contents: str): with open(fn(filename), "w") as f: f.write(contents) @@ -703,7 +725,7 @@ def write_file(filename, contents): git_repo.index.commit("Initial commit") branch_1 = git_repo.create_head("branch-1", "master") - git_repo.head.reference = branch_1 + git_repo.head.reference = branch_1 # type: ignore git_repo.index.reset(index=True, working_tree=True) write_file("file1.txt", "File 1 modified\n") git_repo.index.add("file1.txt") @@ -713,7 +735,7 @@ def write_file(filename, contents): ) branch_2 = git_repo.create_head("branch-2", "master") - git_repo.head.reference = branch_2 + git_repo.head.reference = branch_2 # type: ignore git_repo.index.reset(index=True, working_tree=True) write_file("file2.txt", "File 2 modified\n") git_repo.index.add("file2.txt") @@ -723,7 +745,7 @@ def write_file(filename, contents): ) branch_1_2 = git_repo.create_head("branch-1-2", "master") - git_repo.head.reference = branch_1_2 + git_repo.head.reference = branch_1_2 # type: ignore git_repo.index.reset(index=True, working_tree=True) write_file("file1.txt", "File 1 modified\n") write_file("file2.txt", "File 2 modified\n") @@ -735,7 +757,7 @@ def write_file(filename, contents): ) branch_3 = git_repo.create_head("branch-3", "master") - git_repo.head.reference = branch_3 + git_repo.head.reference = branch_3 # type: ignore git_repo.index.reset(index=True, working_tree=True) write_file("file1.txt", "File 1 modified\n") write_file("file2.txt", "File 2 modified\n") @@ -756,7 +778,7 @@ def write_file(filename, contents): "rapids_pre_commit_hooks.copyright.get_target_branch", Mock(return_value="branch-1-2"), ): - changed_files = copyright.get_changed_files(None) + changed_files = copyright.get_changed_files(Mock()) assert { path: old_blob.path if old_blob else None for path, old_blob in changed_files.items() @@ -786,7 +808,9 @@ def test_normalize_git_filename(): ) -def test_find_blob(git_repo): +def test_find_blob(git_repo: git.Repo): + assert git_repo.working_tree_dir is not None + with open(os.path.join(git_repo.working_tree_dir, "top.txt"), "w"): pass os.mkdir(os.path.join(git_repo.working_tree_dir, "sub1")) @@ -806,25 +830,30 @@ def test_find_blob(git_repo): @freeze_time("2024-01-18") -def test_check_copyright(git_repo): - def fn(filename): +def test_check_copyright(git_repo: git.Repo): + def fn(filename: str) -> str: + assert git_repo.working_tree_dir is not None return os.path.join(git_repo.working_tree_dir, filename) - def write_file(filename, contents): + def write_file(filename: str, contents: str): with open(fn(filename), "w") as f: f.write(contents) - def file_contents(num): - return rf""" -Copyright (c) 2021-2023 NVIDIA CORPORATION -File {num} -""" + def file_contents(num: int) -> str: + return dedent( + rf"""\ + Copyright (c) 2021-2023 NVIDIA CORPORATION + File {num} + """ + ) - def file_contents_modified(num): - return rf""" -Copyright (c) 2021-2023 NVIDIA CORPORATION -File {num} modified -""" + def file_contents_modified(num: int) -> str: + return dedent( + rf"""\ + Copyright (c) 2021-2023 NVIDIA CORPORATION + File {num} modified + """ + ) write_file("file1.txt", file_contents(1)) write_file("file2.txt", file_contents(2)) @@ -834,21 +863,21 @@ def file_contents_modified(num): git_repo.index.commit("Initial commit") branch_1 = git_repo.create_head("branch-1", "master") - git_repo.head.reference = branch_1 + git_repo.head.reference = branch_1 # type: ignore git_repo.head.reset(index=True, working_tree=True) write_file("file1.txt", file_contents_modified(1)) git_repo.index.add(["file1.txt"]) git_repo.index.commit("Update file1.txt") branch_2 = git_repo.create_head("branch-2", "master") - git_repo.head.reference = branch_2 + git_repo.head.reference = branch_2 # type: ignore git_repo.head.reset(index=True, working_tree=True) write_file("file2.txt", file_contents_modified(2)) git_repo.index.add(["file2.txt"]) git_repo.index.commit("Update file2.txt") pr = git_repo.create_head("pr", "branch-1") - git_repo.head.reference = pr + git_repo.head.reference = pr # type: ignore git_repo.head.reset(index=True, working_tree=True) write_file("file3.txt", file_contents_modified(3)) git_repo.index.add(["file3.txt"]) @@ -864,8 +893,8 @@ def file_contents_modified(num): def mock_repo_cwd(): return patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)) - def mock_target_branch_upstream_commit(target_branch): - def func(repo, args): + def mock_target_branch_upstream_commit(target_branch: str): + def func(repo: git.Repo, args: argparse.Namespace) -> git.Commit: assert target_branch == args.target_branch return repo.heads[target_branch].commit From 3f96086f75b1b5e15dafcb3b93cd374e45091111 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 16:49:15 -0400 Subject: [PATCH 08/20] Add type hints to test/rapids_pre_commit_hooks/test_lint.py --- test/rapids_pre_commit_hooks/test_lint.py | 65 ++++++++++++----------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/test/rapids_pre_commit_hooks/test_lint.py b/test/rapids_pre_commit_hooks/test_lint.py index 117fe7f..bf63c8c 100644 --- a/test/rapids_pre_commit_hooks/test_lint.py +++ b/test/rapids_pre_commit_hooks/test_lint.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import contextlib import os.path -import tempfile +from typing import BinaryIO, Generator, TextIO from unittest.mock import Mock, call, patch import pytest @@ -104,7 +105,13 @@ def test_lines(self): ("line 1", 6, 0, contextlib.nullcontext()), ], ) - def test_line_for_pos(self, contents, pos, line, raises): + def test_line_for_pos( + self, + contents: str, + pos: int, + line: int, + raises: contextlib.AbstractContextManager, + ): linter = Linter("test.txt", contents) with raises: assert linter.line_for_pos(pos) == line @@ -135,42 +142,40 @@ def test_fix(self): class TestLintMain: @pytest.fixture - def hello_world_file(self): - with tempfile.NamedTemporaryFile("w+") as f: + def hello_world_file(self, tmp_path: str) -> Generator[TextIO, None, None]: + with open(os.path.join(tmp_path, "hello_world.txt"), "w+") as f: f.write("Hello world!") f.flush() f.seek(0) yield f @pytest.fixture - def hello_file(self): - with tempfile.NamedTemporaryFile("w+") as f: + def hello_file(self, tmp_path: str) -> Generator[TextIO, None, None]: + with open(os.path.join(tmp_path, "hello.txt"), "w+") as f: f.write("Hello!") f.flush() f.seek(0) yield f @pytest.fixture - def binary_file(self): - with tempfile.NamedTemporaryFile("wb+") as f: + def binary_file(self, tmp_path: str) -> Generator[BinaryIO, None, None]: + with open(os.path.join(tmp_path, "binary.bin"), "wb+") as f: f.write(b"\xDE\xAD\xBE\xEF") f.flush() f.seek(0) yield f @pytest.fixture - def long_file(self): - with tempfile.NamedTemporaryFile("w+") as f: + def long_file(self, tmp_path: str) -> Generator[TextIO, None, None]: + with open(os.path.join(tmp_path, "long.txt"), "w+") as f: f.write("This is a long file\nIt has multiple lines\n") f.flush() f.seek(0) yield f @pytest.fixture - def bracket_file(self): - with tempfile.TemporaryDirectory() as d, open( - os.path.join(d, "file[with]brackets.txt"), "w+" - ) as f: + def bracket_file(self, tmp_path: str) -> Generator[TextIO, None, None]: + with open(os.path.join(tmp_path, "file[with]brackets.txt"), "w+") as f: f.write("This [file] [has] [brackets]\n") f.flush() f.seek(0) @@ -184,7 +189,7 @@ def mock_console(self): ): yield m - def the_check(self, linter, args): + def the_check(self, linter: Linter, args: argparse.Namespace): assert args.check_test linter.add_warning((0, 5), "say good bye instead").add_replacement( (0, 5), "Good bye" @@ -192,25 +197,25 @@ def the_check(self, linter, args): if linter.content[5] != "!": linter.add_warning((5, 5), "use punctuation").add_replacement((5, 5), ",") - def long_file_check(self, linter, args): + def long_file_check(self, linter: Linter, args: argparse.Namespace): linter.add_warning((0, len(linter.content)), "this is a long file") - def long_fix_check(self, linter, args): + def long_fix_check(self, linter: Linter, args: argparse.Namespace): linter.add_warning((0, 19), "this is a long line").add_replacement( (0, 19), "This is a long file\nIt's even longer now" ) - def long_delete_fix_check(self, linter, args): + def long_delete_fix_check(self, linter: Linter, args: argparse.Namespace): linter.add_warning( (0, len(linter.content)), "this is a long file" ).add_replacement((0, len(linter.content)), "This is a short file now") - def bracket_check(self, linter, args): + def bracket_check(self, linter: Linter, args: argparse.Namespace): linter.add_warning((0, 28), "this [file] has brackets").add_replacement( (12, 17), "[has more]" ) - def test_no_warnings_no_fix(self, hello_world_file): + def test_no_warnings_no_fix(self, hello_world_file: TextIO): with patch( "sys.argv", ["check-test", "--check-test", hello_world_file.name] ), self.mock_console() as console: @@ -223,7 +228,7 @@ def test_no_warnings_no_fix(self, hello_world_file): call(highlight=False), ] - def test_no_warnings_fix(self, hello_world_file): + def test_no_warnings_fix(self, hello_world_file: TextIO): with patch( "sys.argv", ["check-test", "--check-test", "--fix", hello_world_file.name] ), self.mock_console() as console: @@ -236,7 +241,7 @@ def test_no_warnings_fix(self, hello_world_file): call(highlight=False), ] - def test_warnings_no_fix(self, hello_world_file): + def test_warnings_no_fix(self, hello_world_file: TextIO): with patch( "sys.argv", ["check-test", "--check-test", hello_world_file.name] ), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"): @@ -267,7 +272,7 @@ def test_warnings_no_fix(self, hello_world_file): call().print(), ] - def test_warnings_fix(self, hello_world_file): + def test_warnings_fix(self, hello_world_file: TextIO): with patch( "sys.argv", ["check-test", "--check-test", "--fix", hello_world_file.name] ), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"): @@ -298,7 +303,7 @@ def test_warnings_fix(self, hello_world_file): call().print(), ] - def test_multiple_files(self, hello_world_file, hello_file): + def test_multiple_files(self, hello_world_file: TextIO, hello_file: TextIO): with patch( "sys.argv", [ @@ -347,7 +352,7 @@ def test_multiple_files(self, hello_world_file, hello_file): call().print(), ] - def test_binary_file(self, binary_file): + def test_binary_file(self, binary_file: BinaryIO): mock_linter = Mock(wraps=Linter) with patch( "sys.argv", @@ -367,7 +372,7 @@ def test_binary_file(self, binary_file): ctx.add_check(self.the_check) mock_linter.assert_not_called() - def test_long_file(self, long_file): + def test_long_file(self, long_file: TextIO): with patch( "sys.argv", [ @@ -405,7 +410,7 @@ def test_long_file(self, long_file): call().print(), ] - def test_long_file_delete(self, long_file): + def test_long_file_delete(self, long_file: TextIO): with patch( "sys.argv", [ @@ -438,7 +443,7 @@ def test_long_file_delete(self, long_file): call().print(), ] - def test_long_file_fix(self, long_file): + def test_long_file_fix(self, long_file: TextIO): with patch( "sys.argv", [ @@ -477,7 +482,7 @@ def test_long_file_fix(self, long_file): call().print(), ] - def test_long_file_delete_fix(self, long_file): + def test_long_file_delete_fix(self, long_file: TextIO): with patch( "sys.argv", [ @@ -505,7 +510,7 @@ def test_long_file_delete_fix(self, long_file): call().print(), ] - def test_bracket_file(self, bracket_file): + def test_bracket_file(self, bracket_file: TextIO): with patch( "sys.argv", [ From 1af8f78d04ad61f94319d9351726c8f8db91987d Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 16:52:30 -0400 Subject: [PATCH 09/20] Add type hints to test/rapids_pre_commit_hooks/test_pyproject_license.py --- .../test_pyproject_license.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test/rapids_pre_commit_hooks/test_pyproject_license.py b/test/rapids_pre_commit_hooks/test_pyproject_license.py index 9bb4c7c..6120ebc 100644 --- a/test/rapids_pre_commit_hooks/test_pyproject_license.py +++ b/test/rapids_pre_commit_hooks/test_pyproject_license.py @@ -13,6 +13,7 @@ # limitations under the License. from textwrap import dedent +from unittest.mock import Mock import pytest import tomlkit @@ -20,6 +21,8 @@ from rapids_pre_commit_hooks import pyproject_license from rapids_pre_commit_hooks.lint import Linter +_LocType = tuple[int, int] + @pytest.mark.parametrize( ["key", "append", "loc"], @@ -51,7 +54,7 @@ ), ], ) -def test_find_value_location(key, append, loc): +def test_find_value_location(key: tuple[str, ...], append: bool, loc: _LocType): CONTENT = dedent( """\ [table] @@ -179,10 +182,14 @@ def test_find_value_location(key, append, loc): ], ) def test_check_pyproject_license( - document, loc, message, replacement_loc, replacement_text + document: str, + loc: _LocType, + message: str, + replacement_loc: _LocType, + replacement_text: str, ): linter = Linter("pyproject.toml", document) - pyproject_license.check_pyproject_license(linter, None) + pyproject_license.check_pyproject_license(linter, Mock()) expected_linter = Linter("pyproject.toml", document) if loc and message: From 5c9aa5bd31bf2c52df95254382e8185bc7fb2cb7 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 16:53:07 -0400 Subject: [PATCH 10/20] Add type hints to test/rapids_pre_commit_hooks/test_shell.py --- test/rapids_pre_commit_hooks/test_shell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rapids_pre_commit_hooks/test_shell.py b/test/rapids_pre_commit_hooks/test_shell.py index f5ddbf2..1d2996b 100644 --- a/test/rapids_pre_commit_hooks/test_shell.py +++ b/test/rapids_pre_commit_hooks/test_shell.py @@ -18,7 +18,7 @@ from rapids_pre_commit_hooks.shell.verify_conda_yes import VerifyCondaYesVisitor -def run_shell_linter(content, cls): +def run_shell_linter(content: str, cls: type) -> Linter: linter = Linter("test.sh", content) visitor = cls(linter, None) parts = bashlex.parse(content) From 556322e640d4ccfbfea5c76fa5dec76b7fa6d32f Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 23 Aug 2024 17:01:32 -0400 Subject: [PATCH 11/20] Add type hints to test/test_pre_commit.py --- test/test_pre_commit.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/test/test_pre_commit.py b/test/test_pre_commit.py index 8a6eee5..9809e55 100644 --- a/test/test_pre_commit.py +++ b/test/test_pre_commit.py @@ -19,11 +19,13 @@ import subprocess import sys from functools import cache +from typing import Generator, Optional, Union import git import pytest import yaml from packaging.version import Version +from rapids_metadata.metadata import RAPIDSMetadata from rapids_metadata.remote import fetch_latest REPO_DIR = os.path.join(os.path.dirname(__file__), "..") @@ -33,12 +35,12 @@ @cache -def all_metadata(): +def all_metadata() -> RAPIDSMetadata: return fetch_latest() @contextlib.contextmanager -def set_cwd(cwd): +def set_cwd(cwd: Union[str, os.PathLike[str]]) -> Generator: old_cwd = os.getcwd() os.chdir(cwd) try: @@ -48,7 +50,7 @@ def set_cwd(cwd): @pytest.fixture -def git_repo(tmp_path): +def git_repo(tmp_path: str) -> git.Repo: repo = git.Repo.init(tmp_path) with repo.config_writer() as w: w.set_value("user", "name", "RAPIDS Test Fixtures") @@ -56,8 +58,12 @@ def git_repo(tmp_path): return repo -def run_pre_commit(git_repo, hook_name, expected_status, exc): - def list_files(top): +def run_pre_commit( + git_repo: git.Repo, hook_name: str, expected_status: str, exc: Optional[type] +): + assert git_repo.working_tree_dir is not None + + def list_files(top: str) -> Generator[str, None, None]: for dirpath, _, filenames in os.walk(top): for filename in filenames: yield filename if top == dirpath else os.path.join( @@ -72,7 +78,7 @@ def list_files(top): f.write(f"{max(all_metadata().versions.keys(), key=Version)}\n") git_repo.index.add("VERSION") - git_repo.index.add(list_files(master_dir)) + git_repo.index.add(list(list_files(master_dir))) git_repo.index.commit( "Initial commit", commit_date=datetime.datetime(2023, 2, 1, tzinfo=datetime.timezone.utc), @@ -80,10 +86,12 @@ def list_files(top): branch_dir = os.path.join(example_dir, "branch") if os.path.exists(branch_dir): - git_repo.head.reference = git_repo.create_head("branch", git_repo.head.commit) - git_repo.index.remove(list_files(master_dir), working_tree=True) + git_repo.head.reference = git_repo.create_head( # type: ignore + "branch", git_repo.head.commit + ) + git_repo.index.remove(list(list_files(master_dir)), working_tree=True) shutil.copytree(branch_dir, git_repo.working_tree_dir, dirs_exist_ok=True) - git_repo.index.add(list_files(branch_dir)) + git_repo.index.add(list(list_files(branch_dir))) git_repo.index.commit( "Make some changes", commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc), @@ -102,7 +110,7 @@ def list_files(top): "hook_name", ALL_HOOKS, ) -def test_pre_commit_pass(git_repo, hook_name): +def test_pre_commit_pass(git_repo: git.Repo, hook_name: str): run_pre_commit(git_repo, hook_name, "pass", None) @@ -110,5 +118,5 @@ def test_pre_commit_pass(git_repo, hook_name): "hook_name", ALL_HOOKS, ) -def test_pre_commit_fail(git_repo, hook_name): +def test_pre_commit_fail(git_repo: git.Repo, hook_name: str): run_pre_commit(git_repo, hook_name, "fail", subprocess.CalledProcessError) From 93000b10e51950114cc0f7fde8d26d9500acca66 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 26 Aug 2024 11:49:14 -0400 Subject: [PATCH 12/20] Add more type hints, refactor test --- src/rapids_pre_commit_hooks/alpha_spec.py | 25 ++++--- src/rapids_pre_commit_hooks/copyright.py | 16 ++-- src/rapids_pre_commit_hooks/lint.py | 28 +++---- .../pyproject_license.py | 8 +- src/rapids_pre_commit_hooks/shell/__init__.py | 12 +-- .../shell/verify_conda_yes.py | 4 +- .../test_alpha_spec.py | 34 +++++---- .../rapids_pre_commit_hooks/test_copyright.py | 48 +++++++----- test/rapids_pre_commit_hooks/test_lint.py | 75 ++++++++++--------- .../test_pyproject_license.py | 4 +- test/rapids_pre_commit_hooks/test_shell.py | 2 +- test/test_pre_commit.py | 6 +- 12 files changed, 142 insertions(+), 120 deletions(-) diff --git a/src/rapids_pre_commit_hooks/alpha_spec.py b/src/rapids_pre_commit_hooks/alpha_spec.py index 9dd1f5a..4d83467 100644 --- a/src/rapids_pre_commit_hooks/alpha_spec.py +++ b/src/rapids_pre_commit_hooks/alpha_spec.py @@ -16,6 +16,7 @@ import os import re from functools import cache, total_ordering +from typing import Optional import yaml from packaging.requirements import InvalidRequirement, Requirement @@ -39,7 +40,7 @@ def all_metadata() -> RAPIDSMetadata: return fetch_latest() -def node_has_type(node: yaml.Node, tag_type: str): +def node_has_type(node: yaml.Node, tag_type: str) -> bool: return node.tag == f"tag:yaml.org,2002:{tag_type}" @@ -62,7 +63,7 @@ def strip_cuda_suffix(args: argparse.Namespace, name: str) -> str: def check_and_mark_anchor( anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node -): +) -> tuple[bool, Optional[str]]: for key, value in anchors.items(): if value == node: anchor = key @@ -82,7 +83,7 @@ def check_package_spec( anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node, -): +) -> None: @total_ordering class SpecPriority: def __init__(self, spec: str): @@ -158,7 +159,7 @@ def check_packages( anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node, -): +) -> None: if node_has_type(node, "seq"): descend, _ = check_and_mark_anchor(anchors, used_anchors, node) if descend: @@ -172,7 +173,7 @@ def check_common( anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node, -): +) -> None: if node_has_type(node, "seq"): for dependency_set in node.value: if node_has_type(dependency_set, "map"): @@ -192,7 +193,7 @@ def check_matrices( anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node, -): +) -> None: if node_has_type(node, "seq"): for item in node.value: if node_has_type(item, "map"): @@ -212,7 +213,7 @@ def check_specific( anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node, -): +) -> None: if node_has_type(node, "seq"): for matrix_matcher in node.value: if node_has_type(matrix_matcher, "map"): @@ -232,7 +233,7 @@ def check_dependencies( anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node, -): +) -> None: if node_has_type(node, "map"): for _, dependencies_value in node.value: if node_has_type(dependencies_value, "map"): @@ -254,7 +255,7 @@ def check_root( anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node, -): +) -> None: if node_has_type(node, "map"): for root_key, root_value in node.value: if node_has_type(root_key, "str") and root_key.value == "dependencies": @@ -267,7 +268,7 @@ class AnchorPreservingLoader(yaml.SafeLoader): dictionary for each parsed document. """ - def __init__(self, stream): + def __init__(self, stream) -> None: super().__init__(stream) self.document_anchors: list[dict[str, yaml.Node]] = [] @@ -287,7 +288,7 @@ def compose_document(self) -> yaml.Node: return node -def check_alpha_spec(linter: Linter, args: argparse.Namespace): +def check_alpha_spec(linter: Linter, args: argparse.Namespace) -> None: loader = AnchorPreservingLoader(linter.content) try: root = loader.get_single_node() @@ -297,7 +298,7 @@ def check_alpha_spec(linter: Linter, args: argparse.Namespace): check_root(linter, args, loader.document_anchors[0], set(), root) -def main(): +def main() -> None: m = LintMain() m.argparser.description = ( "Verify that RAPIDS packages in dependencies.yaml do (or do not) have " diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index ac16f82..8f236c4 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -58,7 +58,9 @@ def append_stripped(start: int, item: re.Match): return lines -def apply_copyright_revert(linter: Linter, old_match: re.Match, new_match: re.Match): +def apply_copyright_revert( + linter: Linter, old_match: re.Match, new_match: re.Match +) -> None: if old_match.group("years") == new_match.group("years"): warning_pos = new_match.span() else: @@ -69,7 +71,7 @@ def apply_copyright_revert(linter: Linter, old_match: re.Match, new_match: re.Ma ).add_replacement(new_match.span(), old_match.group()) -def apply_copyright_update(linter: Linter, match: re.Match, year: int): +def apply_copyright_update(linter: Linter, match: re.Match, year: int) -> None: linter.add_warning(match.span("years"), "copyright is out of date").add_replacement( match.span(), COPYRIGHT_REPLACEMENT.format( @@ -79,7 +81,7 @@ def apply_copyright_update(linter: Linter, match: re.Match, year: int): ) -def apply_copyright_check(linter: Linter, old_content: str): +def apply_copyright_check(linter: Linter, old_content: Optional[str]) -> None: if linter.content != old_content: current_year = datetime.datetime.now().year new_copyright_matches = match_copyright(linter.content) @@ -266,14 +268,16 @@ def get_changed_files( return changed_files -def normalize_git_filename(filename: Union[str, os.PathLike[str]]): +def normalize_git_filename(filename: Union[str, os.PathLike[str]]) -> Optional[str]: relpath = os.path.relpath(filename) if re.search(r"^\.\.(/|$)", relpath): return None return relpath -def find_blob(tree: git.Tree, filename: Union[str, os.PathLike[str]]): +def find_blob( + tree: git.Tree, filename: Union[str, os.PathLike[str]] +) -> Optional[git.Blob]: d1, d2 = os.path.split(filename) split = [d2] while d1: @@ -322,7 +326,7 @@ def the_check(linter: Linter, args: argparse.Namespace): return the_check -def main(): +def main() -> None: m = LintMain() m.argparser.description = ( "Verify that all files have had their copyright notices updated. Each file " diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index c7a1cd8..e98851d 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -48,7 +48,7 @@ class BinaryFileWarning(Warning): class Replacement: - def __init__(self, pos: _PosType, newtext: str): + def __init__(self, pos: _PosType, newtext: str) -> None: self.pos: _PosType = pos self.newtext: str = newtext @@ -62,12 +62,12 @@ def __repr__(self) -> str: class LintWarning: - def __init__(self, pos: _PosType, msg: str): + def __init__(self, pos: _PosType, msg: str) -> None: self.pos: _PosType = pos self.msg: str = msg self.replacements: list[Replacement] = [] - def add_replacement(self, pos: _PosType, newtext: str): + def add_replacement(self, pos: _PosType, newtext: str) -> None: self.replacements.append(Replacement(pos, newtext)) def __eq__(self, other: object) -> bool: @@ -89,16 +89,16 @@ def __repr__(self) -> str: class Linter: - NEWLINE_RE = re.compile("[\r\n]") + NEWLINE_RE: re.Pattern = re.compile("[\r\n]") - def __init__(self, filename: str, content: str): + def __init__(self, filename: str, content: str) -> None: self.filename: str = filename self.content: str = content self.warnings: list[LintWarning] = [] self.console: Console = Console(highlight=False) self._calculate_lines() - def add_warning(self, pos: _PosType, msg: str): + def add_warning(self, pos: _PosType, msg: str) -> LintWarning: w = LintWarning(pos, msg) self.warnings.append(w) return w @@ -127,7 +127,7 @@ def fix(self) -> str: replaced_content += self.content[cursor:] return replaced_content - def print_warnings(self, fix_applied: bool = False): + def print_warnings(self, fix_applied: bool = False) -> None: sorted_warnings = sorted(self.warnings, key=lambda warning: warning.pos) for warning in sorted_warnings: @@ -176,7 +176,9 @@ def print_warnings(self, fix_applied: bool = False): self.console.print("[bold]note:[/bold] suggested fix") self.console.print() - def print_highlighted_code(self, pos: _PosType, replacement: Optional[str] = None): + def print_highlighted_code( + self, pos: _PosType, replacement: Optional[str] = None + ) -> None: line_index = self.line_for_pos(pos[0]) line_pos = self.lines[line_index] left = pos[0] @@ -230,7 +232,7 @@ def __eq__(self, other): raise IndexError(f"Position {index} is inside a line separator") return line_index - def _calculate_lines(self): + def _calculate_lines(self) -> None: self.lines: list[_PosType] = [] line_begin = 0 @@ -263,14 +265,14 @@ def _calculate_lines(self): class ExecutionContext(contextlib.AbstractContextManager): - def __init__(self, args: argparse.Namespace): + def __init__(self, args: argparse.Namespace) -> None: self.args: argparse.Namespace = args self.checks: list[Callable[[Linter, argparse.Namespace], None]] = [] - def add_check(self, check: Callable[[Linter, argparse.Namespace], None]): + def add_check(self, check: Callable[[Linter, argparse.Namespace], None]) -> None: self.checks.append(check) - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: if exc_type: return @@ -308,7 +310,7 @@ def __exit__(self, exc_type, exc_value, traceback): class LintMain: context_class = ExecutionContext - def __init__(self): + def __init__(self) -> None: self.argparser: argparse.ArgumentParser = argparse.ArgumentParser() self.argparser.add_argument( "--fix", action="store_true", help="automatically fix warnings" diff --git a/src/rapids_pre_commit_hooks/pyproject_license.py b/src/rapids_pre_commit_hooks/pyproject_license.py index 849605e..d8622cf 100644 --- a/src/rapids_pre_commit_hooks/pyproject_license.py +++ b/src/rapids_pre_commit_hooks/pyproject_license.py @@ -21,8 +21,8 @@ from .lint import Linter, LintMain -RAPIDS_LICENSE = "Apache 2.0" -ACCEPTABLE_LICENSES = { +RAPIDS_LICENSE: str = "Apache 2.0" +ACCEPTABLE_LICENSES: set[str] = { RAPIDS_LICENSE, "BSD-3-Clause", } @@ -60,7 +60,7 @@ def find_value_location( return begin_loc, end_loc -def check_pyproject_license(linter: Linter, args: argparse.Namespace): +def check_pyproject_license(linter: Linter, args: argparse.Namespace) -> None: document = tomlkit.loads(linter.content) try: add_project_table = True @@ -93,7 +93,7 @@ def check_pyproject_license(linter: Linter, args: argparse.Namespace): linter.add_warning(loc, f'license should be "{RAPIDS_LICENSE}"') -def main(): +def main() -> None: m = LintMain() m.argparser.description = ( f'Verify that pyproject.toml has the correct license ("{RAPIDS_LICENSE}").' diff --git a/src/rapids_pre_commit_hooks/shell/__init__.py b/src/rapids_pre_commit_hooks/shell/__init__.py index 366240e..386da01 100644 --- a/src/rapids_pre_commit_hooks/shell/__init__.py +++ b/src/rapids_pre_commit_hooks/shell/__init__.py @@ -16,30 +16,30 @@ import bashlex -from ..lint import ExecutionContext, Linter, LintMain +from ..lint import ExecutionContext, Linter, LintMain, LintWarning _PosType = tuple[int, int] class LintVisitor(bashlex.ast.nodevisitor): - def __init__(self, linter: Linter, args: argparse.Namespace): + def __init__(self, linter: Linter, args: argparse.Namespace) -> None: self.linter: Linter = linter self.args: argparse.Namespace = args - def add_warning(self, pos: _PosType, msg: str): + def add_warning(self, pos: _PosType, msg: str) -> LintWarning: return self.linter.add_warning(pos, msg) class ShellExecutionContext(ExecutionContext): - def __init__(self, args: argparse.Namespace): + def __init__(self, args: argparse.Namespace) -> None: super().__init__(args) self.visitors: list[type] = [] self.add_check(self.check_shell) - def add_visitor_class(self, cls: type): + def add_visitor_class(self, cls: type) -> None: self.visitors.append(cls) - def check_shell(self, linter: Linter, args: argparse.Namespace): + def check_shell(self, linter: Linter, args: argparse.Namespace) -> None: parts = bashlex.parse(linter.content) for cls in self.visitors: diff --git a/src/rapids_pre_commit_hooks/shell/verify_conda_yes.py b/src/rapids_pre_commit_hooks/shell/verify_conda_yes.py index cf6c175..8c24116 100644 --- a/src/rapids_pre_commit_hooks/shell/verify_conda_yes.py +++ b/src/rapids_pre_commit_hooks/shell/verify_conda_yes.py @@ -40,7 +40,7 @@ class VerifyCondaYesVisitor(LintVisitor): - def visitcommand(self, n, parts): + def visitcommand(self, n, parts) -> None: part_words = [part.word for part in parts] if part_words[0] != "conda": return @@ -73,7 +73,7 @@ def visitcommand(self, n, parts): warning.add_replacement(insert_pos, f" {command['args'][0]}") -def main(): +def main() -> None: m = ShellMain() with m.execute() as ctx: ctx.add_visitor_class(VerifyCondaYesVisitor) diff --git a/test/rapids_pre_commit_hooks/test_alpha_spec.py b/test/rapids_pre_commit_hooks/test_alpha_spec.py index 715bd7e..9412004 100644 --- a/test/rapids_pre_commit_hooks/test_alpha_spec.py +++ b/test/rapids_pre_commit_hooks/test_alpha_spec.py @@ -58,7 +58,7 @@ def test_get_rapids_version( version_arg: Optional[str], expected_version: Optional[str], raises: contextlib.AbstractContextManager, -): +) -> None: MOCK_METADATA = RAPIDSMetadata( versions={ "24.06": RAPIDSVersion( @@ -87,10 +87,11 @@ def test_get_rapids_version( assert version == MOCK_METADATA.versions[expected_version] -def test_anchor_preserving_loader(): +def test_anchor_preserving_loader() -> None: loader = alpha_spec.AnchorPreservingLoader("- &a A\n- *a") try: root = loader.get_single_node() + assert root is not None finally: loader.dispose() assert loader.document_anchors == [{"a": root.value[0]}] @@ -128,7 +129,7 @@ def test_anchor_preserving_loader(): "rapids_pre_commit_hooks.alpha_spec.get_rapids_version", Mock(return_value=latest_metadata), ) -def test_strip_cuda_suffix(name: str, stripped_name: str): +def test_strip_cuda_suffix(name: str, stripped_name: str) -> None: assert alpha_spec.strip_cuda_suffix(Mock(), name) == stripped_name @@ -178,7 +179,7 @@ def test_check_and_mark_anchor( descend: bool, anchor: Optional[str], used_anchors_after: set[str], -): +) -> None: NODES = [Mock() for _ in range(3)] ANCHORS: dict[str, yaml.Node] = { "anchor1": NODES[0], @@ -251,7 +252,9 @@ def test_check_and_mark_anchor( "rapids_pre_commit_hooks.alpha_spec.get_rapids_version", Mock(return_value=latest_metadata), ) -def test_check_package_spec(package: str, content: str, mode: str, replacement: str): +def test_check_package_spec( + package: str, content: str, mode: str, replacement: str +) -> None: args = Mock(mode=mode) linter = lint.Linter("dependencies.yaml", content) loader = alpha_spec.AnchorPreservingLoader(content) @@ -281,7 +284,7 @@ def test_check_package_spec(package: str, content: str, mode: str, replacement: "rapids_pre_commit_hooks.alpha_spec.get_rapids_version", Mock(return_value=latest_metadata), ) -def test_check_package_spec_anchor(): +def test_check_package_spec_anchor() -> None: CONTENT = dedent( """\ - &cudf cudf>=24.04,<24.06 @@ -295,9 +298,10 @@ def test_check_package_spec_anchor(): loader = alpha_spec.AnchorPreservingLoader(CONTENT) try: composed = loader.get_single_node() + assert composed is not None finally: loader.dispose() - used_anchors = set() + used_anchors: set[str] = set() expected_linter = lint.Linter("dependencies.yaml", CONTENT) expected_linter.add_warning( @@ -355,7 +359,7 @@ def test_check_package_spec_anchor(): ), ], ) -def test_check_packages(content: str, indices: list[int], use_anchor: bool): +def test_check_packages(content: str, indices: list[int], use_anchor: bool) -> None: with patch( "rapids_pre_commit_hooks.alpha_spec.check_package_spec", Mock() ) as mock_check_package_spec: @@ -393,7 +397,7 @@ def test_check_packages(content: str, indices: list[int], use_anchor: bool): ), ], ) -def test_check_common(content: str, indices: list[tuple[int, int]]): +def test_check_common(content: str, indices: list[tuple[int, int]]) -> None: with patch( "rapids_pre_commit_hooks.alpha_spec.check_packages", Mock() ) as mock_check_packages: @@ -428,7 +432,7 @@ def test_check_common(content: str, indices: list[tuple[int, int]]): ), ], ) -def test_check_matrices(content: str, indices: list[tuple[int, int]]): +def test_check_matrices(content: str, indices: list[tuple[int, int]]) -> None: with patch( "rapids_pre_commit_hooks.alpha_spec.check_packages", Mock() ) as mock_check_packages: @@ -474,7 +478,7 @@ def test_check_matrices(content: str, indices: list[tuple[int, int]]): ), ], ) -def test_check_specific(content: str, indices: list[tuple[int, int]]): +def test_check_specific(content: str, indices: list[tuple[int, int]]) -> None: with patch( "rapids_pre_commit_hooks.alpha_spec.check_matrices", Mock() ) as mock_check_matrices: @@ -531,7 +535,7 @@ def test_check_dependencies( content: str, common_indices: list[tuple[int, int]], specific_indices: list[tuple[int, int]], -): +) -> None: with patch( "rapids_pre_commit_hooks.alpha_spec.check_common", Mock() ) as mock_check_common, patch( @@ -568,7 +572,7 @@ def test_check_dependencies( ), ], ) -def test_check_root(content: str, indices: list[int]): +def test_check_root(content: str, indices: list[int]) -> None: with patch( "rapids_pre_commit_hooks.alpha_spec.check_dependencies", Mock() ) as mock_check_dependencies: @@ -583,7 +587,7 @@ def test_check_root(content: str, indices: list[int]): ] -def test_check_alpha_spec(): +def test_check_alpha_spec() -> None: CONTENT = "dependencies: []" with patch( "rapids_pre_commit_hooks.alpha_spec.check_root", Mock() @@ -603,7 +607,7 @@ def test_check_alpha_spec(): ) -def test_check_alpha_spec_integration(tmp_path: os.PathLike[str]): +def test_check_alpha_spec_integration(tmp_path: os.PathLike[str]) -> None: CONTENT = dedent( """\ dependencies: diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index cbb0ed1..be095b0 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -18,7 +18,7 @@ import tempfile from io import BufferedReader from textwrap import dedent -from typing import Any, TextIO, Union +from typing import Any, Optional, TextIO, Union from unittest.mock import Mock, patch import git @@ -29,7 +29,7 @@ from rapids_pre_commit_hooks.lint import Linter -def test_match_copyright(): +def test_match_copyright() -> None: CONTENT = dedent( r""" Copyright (c) 2024 NVIDIA CORPORATION @@ -70,7 +70,7 @@ def test_match_copyright(): ] -def test_strip_copyright(): +def test_strip_copyright() -> None: CONTENT = dedent( r""" This is a line before the first copyright statement @@ -96,8 +96,10 @@ def test_strip_copyright(): @freeze_time("2024-01-18") -def test_apply_copyright_check(): - def run_apply_copyright_check(old_content: str, new_content: str): +def test_apply_copyright_check() -> None: + def run_apply_copyright_check( + old_content: Optional[str], new_content: str + ) -> Linter: linter = Linter("file.txt", new_content) copyright.apply_copyright_check(linter, old_content) return linter @@ -184,7 +186,7 @@ def git_repo(tmp_path: os.PathLike[str]) -> git.Repo: return repo -def test_get_target_branch(git_repo: git.Repo): +def test_get_target_branch(git_repo: git.Repo) -> None: assert git_repo.working_tree_dir is not None with patch.dict("os.environ", {}, clear=True): @@ -269,7 +271,7 @@ def test_get_target_branch(git_repo: git.Repo): assert copyright.get_target_branch(git_repo, args) == "master" -def test_get_target_branch_upstream_commit(git_repo: git.Repo): +def test_get_target_branch_upstream_commit(git_repo: git.Repo) -> None: def fn(repo: git.Repo, filename: str) -> str: assert repo.working_tree_dir is not None return os.path.join(repo.working_tree_dir, filename) @@ -515,7 +517,7 @@ def mock_target_branch(branch: Any): ) -def test_get_changed_files(git_repo: git.Repo): +def test_get_changed_files(git_repo: git.Repo) -> None: f: Union[BufferedReader, TextIO] assert git_repo.working_tree_dir is not None @@ -709,7 +711,7 @@ def file_contents(verbed: str) -> str: assert changed_files[new].data_stream.read() == old_contents # type: ignore -def test_get_changed_files_multiple_merge_bases(git_repo: git.Repo): +def test_get_changed_files_multiple_merge_bases(git_repo: git.Repo) -> None: def fn(filename: str) -> str: assert git_repo.working_tree_dir is not None return os.path.join(git_repo.working_tree_dir, filename) @@ -789,7 +791,7 @@ def write_file(filename: str, contents: str): } -def test_normalize_git_filename(): +def test_normalize_git_filename() -> None: assert copyright.normalize_git_filename("file.txt") == "file.txt" assert copyright.normalize_git_filename("sub/file.txt") == "sub/file.txt" assert copyright.normalize_git_filename("sub//file.txt") == "sub/file.txt" @@ -808,7 +810,16 @@ def test_normalize_git_filename(): ) -def test_find_blob(git_repo: git.Repo): +@pytest.mark.parametrize( + ["path", "present"], + [ + ("top.txt", True), + ("sub1/sub2/sub.txt", True), + ("nonexistent.txt", False), + ("nonexistent/sub.txt", False), + ], +) +def test_find_blob(git_repo: git.Repo, path: str, present: bool) -> None: assert git_repo.working_tree_dir is not None with open(os.path.join(git_repo.working_tree_dir, "top.txt"), "w"): @@ -820,17 +831,16 @@ def test_find_blob(git_repo: git.Repo): git_repo.index.add(["top.txt", "sub1/sub2/sub.txt"]) git_repo.index.commit("Initial commit") - assert copyright.find_blob(git_repo.head.commit.tree, "top.txt").path == "top.txt" - assert ( - copyright.find_blob(git_repo.head.commit.tree, "sub1/sub2/sub.txt").path - == "sub1/sub2/sub.txt" - ) - assert copyright.find_blob(git_repo.head.commit.tree, "nonexistent.txt") is None - assert copyright.find_blob(git_repo.head.commit.tree, "nonexistent/sub.txt") is None + blob = copyright.find_blob(git_repo.head.commit.tree, path) + if present: + assert blob is not None + assert blob.path == path + else: + assert blob is None @freeze_time("2024-01-18") -def test_check_copyright(git_repo: git.Repo): +def test_check_copyright(git_repo: git.Repo) -> None: def fn(filename: str) -> str: assert git_repo.working_tree_dir is not None return os.path.join(git_repo.working_tree_dir, filename) diff --git a/test/rapids_pre_commit_hooks/test_lint.py b/test/rapids_pre_commit_hooks/test_lint.py index bf63c8c..53687a6 100644 --- a/test/rapids_pre_commit_hooks/test_lint.py +++ b/test/rapids_pre_commit_hooks/test_lint.py @@ -15,6 +15,7 @@ import argparse import contextlib import os.path +from textwrap import dedent from typing import BinaryIO, Generator, TextIO from unittest.mock import Mock, call, patch @@ -29,12 +30,12 @@ class TestLinter: - LONG_CONTENTS = ( + LONG_CONTENTS: str = ( "line 1\nline 2\rline 3\r\nline 4\r\n\nline 6\r\n\r\nline 8\n\r\n" "line 10\r\r\nline 12\r\n\rline 14\n\nline 16\r\rline 18\n\rline 20" ) - def test_lines(self): + def test_lines(self) -> None: linter = Linter("test.txt", self.LONG_CONTENTS) assert linter.lines == [ (0, 6), @@ -111,12 +112,12 @@ def test_line_for_pos( pos: int, line: int, raises: contextlib.AbstractContextManager, - ): + ) -> None: linter = Linter("test.txt", contents) with raises: assert linter.line_for_pos(pos) == line - def test_fix(self): + def test_fix(self) -> None: linter = Linter("test.txt", "Hello world!") assert linter.fix() == "Hello world!" @@ -182,14 +183,14 @@ def bracket_file(self, tmp_path: str) -> Generator[TextIO, None, None]: yield f @contextlib.contextmanager - def mock_console(self): + def mock_console(self) -> Generator[Mock, None, None]: m = Mock() with patch("rich.console.Console", m), patch( "rapids_pre_commit_hooks.lint.Console", m ): yield m - def the_check(self, linter: Linter, args: argparse.Namespace): + def the_check(self, linter: Linter, args: argparse.Namespace) -> None: assert args.check_test linter.add_warning((0, 5), "say good bye instead").add_replacement( (0, 5), "Good bye" @@ -197,25 +198,25 @@ def the_check(self, linter: Linter, args: argparse.Namespace): if linter.content[5] != "!": linter.add_warning((5, 5), "use punctuation").add_replacement((5, 5), ",") - def long_file_check(self, linter: Linter, args: argparse.Namespace): + def long_file_check(self, linter: Linter, args: argparse.Namespace) -> None: linter.add_warning((0, len(linter.content)), "this is a long file") - def long_fix_check(self, linter: Linter, args: argparse.Namespace): + def long_fix_check(self, linter: Linter, args: argparse.Namespace) -> None: linter.add_warning((0, 19), "this is a long line").add_replacement( (0, 19), "This is a long file\nIt's even longer now" ) - def long_delete_fix_check(self, linter: Linter, args: argparse.Namespace): + def long_delete_fix_check(self, linter: Linter, args: argparse.Namespace) -> None: linter.add_warning( (0, len(linter.content)), "this is a long file" ).add_replacement((0, len(linter.content)), "This is a short file now") - def bracket_check(self, linter: Linter, args: argparse.Namespace): + def bracket_check(self, linter: Linter, args: argparse.Namespace) -> None: linter.add_warning((0, 28), "this [file] has brackets").add_replacement( (12, 17), "[has more]" ) - def test_no_warnings_no_fix(self, hello_world_file: TextIO): + def test_no_warnings_no_fix(self, hello_world_file: TextIO) -> None: with patch( "sys.argv", ["check-test", "--check-test", hello_world_file.name] ), self.mock_console() as console: @@ -228,7 +229,7 @@ def test_no_warnings_no_fix(self, hello_world_file: TextIO): call(highlight=False), ] - def test_no_warnings_fix(self, hello_world_file: TextIO): + def test_no_warnings_fix(self, hello_world_file: TextIO) -> None: with patch( "sys.argv", ["check-test", "--check-test", "--fix", hello_world_file.name] ), self.mock_console() as console: @@ -241,7 +242,7 @@ def test_no_warnings_fix(self, hello_world_file: TextIO): call(highlight=False), ] - def test_warnings_no_fix(self, hello_world_file: TextIO): + def test_warnings_no_fix(self, hello_world_file: TextIO) -> None: with patch( "sys.argv", ["check-test", "--check-test", hello_world_file.name] ), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"): @@ -272,7 +273,7 @@ def test_warnings_no_fix(self, hello_world_file: TextIO): call().print(), ] - def test_warnings_fix(self, hello_world_file: TextIO): + def test_warnings_fix(self, hello_world_file: TextIO) -> None: with patch( "sys.argv", ["check-test", "--check-test", "--fix", hello_world_file.name] ), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"): @@ -303,7 +304,7 @@ def test_warnings_fix(self, hello_world_file: TextIO): call().print(), ] - def test_multiple_files(self, hello_world_file: TextIO, hello_file: TextIO): + def test_multiple_files(self, hello_world_file: TextIO, hello_file: TextIO) -> None: with patch( "sys.argv", [ @@ -352,7 +353,7 @@ def test_multiple_files(self, hello_world_file: TextIO, hello_file: TextIO): call().print(), ] - def test_binary_file(self, binary_file: BinaryIO): + def test_binary_file(self, binary_file: BinaryIO) -> None: mock_linter = Mock(wraps=Linter) with patch( "sys.argv", @@ -372,7 +373,7 @@ def test_binary_file(self, binary_file: BinaryIO): ctx.add_check(self.the_check) mock_linter.assert_not_called() - def test_long_file(self, long_file: TextIO): + def test_long_file(self, long_file: TextIO) -> None: with patch( "sys.argv", [ @@ -384,11 +385,11 @@ def test_long_file(self, long_file: TextIO): with m.execute() as ctx: ctx.add_check(self.long_file_check) ctx.add_check(self.long_fix_check) - assert ( - long_file.read() - == """This is a long file -It has multiple lines -""" + assert long_file.read() == dedent( + """\ + This is a long file + It has multiple lines + """ ) assert console.mock_calls == [ call(highlight=False), @@ -410,7 +411,7 @@ def test_long_file(self, long_file: TextIO): call().print(), ] - def test_long_file_delete(self, long_file: TextIO): + def test_long_file_delete(self, long_file: TextIO) -> None: with patch( "sys.argv", [ @@ -421,11 +422,11 @@ def test_long_file_delete(self, long_file: TextIO): m = LintMain() with m.execute() as ctx: ctx.add_check(self.long_delete_fix_check) - assert ( - long_file.read() - == """This is a long file -It has multiple lines -""" + assert long_file.read() == dedent( + """\ + This is a long file + It has multiple lines + """ ) assert console.mock_calls == [ call(highlight=False), @@ -443,7 +444,7 @@ def test_long_file_delete(self, long_file: TextIO): call().print(), ] - def test_long_file_fix(self, long_file: TextIO): + def test_long_file_fix(self, long_file: TextIO) -> None: with patch( "sys.argv", [ @@ -456,12 +457,12 @@ def test_long_file_fix(self, long_file: TextIO): with m.execute() as ctx: ctx.add_check(self.long_file_check) ctx.add_check(self.long_fix_check) - assert ( - long_file.read() - == """This is a long file -It's even longer now -It has multiple lines -""" + assert long_file.read() == dedent( + """\ + This is a long file + It's even longer now + It has multiple lines + """ ) assert console.mock_calls == [ call(highlight=False), @@ -482,7 +483,7 @@ def test_long_file_fix(self, long_file: TextIO): call().print(), ] - def test_long_file_delete_fix(self, long_file: TextIO): + def test_long_file_delete_fix(self, long_file: TextIO) -> None: with patch( "sys.argv", [ @@ -510,7 +511,7 @@ def test_long_file_delete_fix(self, long_file: TextIO): call().print(), ] - def test_bracket_file(self, bracket_file: TextIO): + def test_bracket_file(self, bracket_file: TextIO) -> None: with patch( "sys.argv", [ diff --git a/test/rapids_pre_commit_hooks/test_pyproject_license.py b/test/rapids_pre_commit_hooks/test_pyproject_license.py index 6120ebc..fa24ac9 100644 --- a/test/rapids_pre_commit_hooks/test_pyproject_license.py +++ b/test/rapids_pre_commit_hooks/test_pyproject_license.py @@ -54,7 +54,7 @@ ), ], ) -def test_find_value_location(key: tuple[str, ...], append: bool, loc: _LocType): +def test_find_value_location(key: tuple[str, ...], append: bool, loc: _LocType) -> None: CONTENT = dedent( """\ [table] @@ -187,7 +187,7 @@ def test_check_pyproject_license( message: str, replacement_loc: _LocType, replacement_text: str, -): +) -> None: linter = Linter("pyproject.toml", document) pyproject_license.check_pyproject_license(linter, Mock()) diff --git a/test/rapids_pre_commit_hooks/test_shell.py b/test/rapids_pre_commit_hooks/test_shell.py index 1d2996b..50dbec3 100644 --- a/test/rapids_pre_commit_hooks/test_shell.py +++ b/test/rapids_pre_commit_hooks/test_shell.py @@ -27,7 +27,7 @@ def run_shell_linter(content: str, cls: type) -> Linter: return linter -def test_verify_conda_yes(): +def test_verify_conda_yes() -> None: CONTENT = r""" conda install -y pkg1 conda install --yes pkg2 pkg3 diff --git a/test/test_pre_commit.py b/test/test_pre_commit.py index 9809e55..2fa450f 100644 --- a/test/test_pre_commit.py +++ b/test/test_pre_commit.py @@ -60,7 +60,7 @@ def git_repo(tmp_path: str) -> git.Repo: def run_pre_commit( git_repo: git.Repo, hook_name: str, expected_status: str, exc: Optional[type] -): +) -> None: assert git_repo.working_tree_dir is not None def list_files(top: str) -> Generator[str, None, None]: @@ -110,7 +110,7 @@ def list_files(top: str) -> Generator[str, None, None]: "hook_name", ALL_HOOKS, ) -def test_pre_commit_pass(git_repo: git.Repo, hook_name: str): +def test_pre_commit_pass(git_repo: git.Repo, hook_name: str) -> None: run_pre_commit(git_repo, hook_name, "pass", None) @@ -118,5 +118,5 @@ def test_pre_commit_pass(git_repo: git.Repo, hook_name: str): "hook_name", ALL_HOOKS, ) -def test_pre_commit_fail(git_repo: git.Repo, hook_name: str): +def test_pre_commit_fail(git_repo: git.Repo, hook_name: str) -> None: run_pre_commit(git_repo, hook_name, "fail", subprocess.CalledProcessError) From 8917b35c2fe51dae112c303b42e443e7dcd7c20e Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 26 Aug 2024 13:23:01 -0400 Subject: [PATCH 13/20] Type assert --- src/rapids_pre_commit_hooks/alpha_spec.py | 6 ++---- src/rapids_pre_commit_hooks/lint.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/rapids_pre_commit_hooks/alpha_spec.py b/src/rapids_pre_commit_hooks/alpha_spec.py index 4d83467..ae4fe2a 100644 --- a/src/rapids_pre_commit_hooks/alpha_spec.py +++ b/src/rapids_pre_commit_hooks/alpha_spec.py @@ -90,13 +90,11 @@ def __init__(self, spec: str): self.spec: str = spec def __eq__(self, other: object) -> bool: - if not isinstance(other, SpecPriority): - return False + assert isinstance(other, SpecPriority) return self.spec == other.spec def __lt__(self, other: object) -> bool: - if not isinstance(other, SpecPriority): - return False + assert isinstance(other, SpecPriority) if self.spec == other.spec: return False if self.spec == ALPHA_SPECIFIER: diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index e98851d..ebf0110 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -53,8 +53,7 @@ def __init__(self, pos: _PosType, newtext: str) -> None: self.newtext: str = newtext def __eq__(self, other: object) -> bool: - if not isinstance(other, Replacement): - return False + assert isinstance(other, Replacement) return self.pos == other.pos and self.newtext == other.newtext def __repr__(self) -> str: @@ -71,8 +70,7 @@ def add_replacement(self, pos: _PosType, newtext: str) -> None: self.replacements.append(Replacement(pos, newtext)) def __eq__(self, other: object) -> bool: - if not isinstance(other, LintWarning): - return False + assert isinstance(other, LintWarning) return ( self.pos == other.pos and self.msg == other.msg From 62f24b6be919f3fa3689a44223d3adb851f6a2f8 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 26 Aug 2024 13:33:39 -0400 Subject: [PATCH 14/20] Make type: ignore more specific --- src/rapids_pre_commit_hooks/alpha_spec.py | 2 +- .../pyproject_license.py | 6 +- .../rapids_pre_commit_hooks/test_copyright.py | 56 ++++++++++--------- test/test_pre_commit.py | 2 +- 4 files changed, 36 insertions(+), 30 deletions(-) diff --git a/src/rapids_pre_commit_hooks/alpha_spec.py b/src/rapids_pre_commit_hooks/alpha_spec.py index ae4fe2a..c48d474 100644 --- a/src/rapids_pre_commit_hooks/alpha_spec.py +++ b/src/rapids_pre_commit_hooks/alpha_spec.py @@ -275,7 +275,7 @@ def compose_document(self) -> yaml.Node: self.get_event() # Compose the root node. - node = self.compose_node(None, None) # type: ignore + node = self.compose_node(None, None) # type: ignore[arg-type] # Drop the DOCUMENT-END event. self.get_event() diff --git a/src/rapids_pre_commit_hooks/pyproject_license.py b/src/rapids_pre_commit_hooks/pyproject_license.py index d8622cf..602816a 100644 --- a/src/rapids_pre_commit_hooks/pyproject_license.py +++ b/src/rapids_pre_commit_hooks/pyproject_license.py @@ -44,7 +44,7 @@ def find_value_location( # look for that in the new document. node = copied_document while len(key) > (0 if append else 1): - node = node[key[0]] # type: ignore + node = node[key[0]] # type: ignore[assignment] key = key[1:] if append: node.add(str(placeholder), placeholder_toml) @@ -65,8 +65,8 @@ def check_pyproject_license(linter: Linter, args: argparse.Namespace) -> None: try: add_project_table = True project_table = document["project"] - add_project_table = project_table.is_super_table() # type: ignore - license_value = project_table["license"]["text"] # type: ignore + add_project_table = project_table.is_super_table() # type: ignore[union-attr] + license_value = project_table["license"]["text"] # type: ignore[index] except tomlkit.exceptions.NonExistentKey: if add_project_table: loc = (len(linter.content), len(linter.content)) diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index be095b0..ea920f1 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -318,7 +318,7 @@ def mock_target_branch(branch: Any): remote_1_branch_1 = remote_repo_1.create_head( "branch-1-renamed", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_1 # type: ignore + remote_repo_1.head.reference = remote_1_branch_1 # type: ignore[misc] remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file1.txt", "File 1 modified") remote_repo_1.index.add(["file1.txt"]) @@ -330,7 +330,7 @@ def mock_target_branch(branch: Any): remote_1_branch_2 = remote_repo_1.create_head( "branch-2", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_2 # type: ignore + remote_repo_1.head.reference = remote_1_branch_2 # type: ignore[misc] remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file2.txt", "File 2 modified") remote_repo_1.index.add(["file2.txt"]) @@ -339,7 +339,7 @@ def mock_target_branch(branch: Any): remote_1_branch_3 = remote_repo_1.create_head( "branch-3", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_3 # type: ignore + remote_repo_1.head.reference = remote_1_branch_3 # type: ignore[misc] remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file3.txt", "File 3 modified") remote_repo_1.index.add(["file3.txt"]) @@ -351,7 +351,7 @@ def mock_target_branch(branch: Any): remote_1_branch_4 = remote_repo_1.create_head( "branch-4", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_4 # type: ignore + remote_repo_1.head.reference = remote_1_branch_4 # type: ignore[misc] remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file4.txt", "File 4 modified") remote_repo_1.index.add(["file4.txt"]) @@ -363,7 +363,7 @@ def mock_target_branch(branch: Any): remote_1_branch_7 = remote_repo_1.create_head( "branch-7", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_7 # type: ignore + remote_repo_1.head.reference = remote_1_branch_7 # type: ignore[misc] remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file7.txt", "File 7 modified") remote_repo_1.index.add(["file7.txt"]) @@ -379,7 +379,7 @@ def mock_target_branch(branch: Any): remote_2_branch_3 = remote_repo_2.create_head( "branch-3", remote_2_master.commit ) - remote_repo_2.head.reference = remote_2_branch_3 # type: ignore + remote_repo_2.head.reference = remote_2_branch_3 # type: ignore[misc] remote_repo_2.head.reset(index=True, working_tree=True) write_file(remote_repo_2, "file3.txt", "File 3 modified") remote_repo_2.index.add(["file3.txt"]) @@ -391,7 +391,7 @@ def mock_target_branch(branch: Any): remote_2_branch_4 = remote_repo_2.create_head( "branch-4", remote_2_master.commit ) - remote_repo_2.head.reference = remote_2_branch_4 # type: ignore + remote_repo_2.head.reference = remote_2_branch_4 # type: ignore[misc] remote_repo_2.head.reset(index=True, working_tree=True) write_file(remote_repo_2, "file4.txt", "File 4 modified") remote_repo_2.index.add(["file4.txt"]) @@ -403,7 +403,7 @@ def mock_target_branch(branch: Any): remote_2_branch_5 = remote_repo_2.create_head( "branch-5", remote_2_master.commit ) - remote_repo_2.head.reference = remote_2_branch_5 # type: ignore + remote_repo_2.head.reference = remote_2_branch_5 # type: ignore[misc] remote_repo_2.head.reset(index=True, working_tree=True) write_file(remote_repo_2, "file5.txt", "File 5 modified") remote_repo_2.index.add(["file5.txt"]) @@ -433,7 +433,7 @@ def mock_target_branch(branch: Any): with branch_1.config_writer() as w: w.set_value("remote", "unconventional/remote/name/1") w.set_value("merge", "branch-1-renamed") - git_repo.head.reference = branch_1 # type: ignore + git_repo.head.reference = branch_1 # type: ignore[misc] git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove("file1.txt", working_tree=True) git_repo.index.commit( @@ -442,7 +442,7 @@ def mock_target_branch(branch: Any): ) branch_6 = git_repo.create_head("branch-6", remote_1.refs["master"]) - git_repo.head.reference = branch_6 # type: ignore + git_repo.head.reference = branch_6 # type: ignore[misc] git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove(["file6.txt"], working_tree=True) git_repo.index.commit("Remove file6.txt") @@ -451,7 +451,7 @@ def mock_target_branch(branch: Any): with branch_7.config_writer() as w: w.set_value("remote", "unconventional/remote/name/1") w.set_value("merge", "branch-7") - git_repo.head.reference = branch_7 # type: ignore + git_repo.head.reference = branch_7 # type: ignore[misc] git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove(["file7.txt"], working_tree=True) git_repo.index.commit( @@ -459,7 +459,7 @@ def mock_target_branch(branch: Any): commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc), ) - git_repo.head.reference = main # type: ignore + git_repo.head.reference = main # type: ignore[misc] git_repo.head.reset(index=True, working_tree=True) with mock_target_branch("branch-1"): @@ -614,7 +614,7 @@ def file_contents(verbed: str) -> str: git_repo.index.commit("Remove modified.txt") pr_branch = git_repo.create_head("pr", "HEAD~") - git_repo.head.reference = pr_branch # type: ignore + git_repo.head.reference = pr_branch # type: ignore[misc] git_repo.head.reset(index=True, working_tree=True) write_file("copied_2.txt", file_contents("copied")) @@ -662,7 +662,7 @@ def file_contents(verbed: str) -> str: target_branch = git_repo.heads["master"] merge_base = git_repo.merge_base(target_branch, "HEAD")[0] old_files = { - blob.path: blob # type: ignore + blob.path: blob # type: ignore[union-attr] for blob in merge_base.tree.traverse(lambda b, _: isinstance(b, git.Blob)) } @@ -698,17 +698,23 @@ def file_contents(verbed: str) -> str: if old: with open(fn(new), "rb") as f: new_contents = f.read() - old_contents = old_files[old].data_stream.read() # type: ignore + old_contents = old_files[old].data_stream.read() # type: ignore[union-attr] assert new_contents != old_contents - assert changed_files[new].data_stream.read() == old_contents # type: ignore + assert ( + changed_files[new].data_stream.read() # type: ignore[union-attr] + == old_contents + ) for new, old in superfluous.items(): if old: with open(fn(new), "rb") as f: new_contents = f.read() - old_contents = old_files[old].data_stream.read() # type: ignore + old_contents = old_files[old].data_stream.read() # type: ignore[union-attr] assert new_contents == old_contents - assert changed_files[new].data_stream.read() == old_contents # type: ignore + assert ( + changed_files[new].data_stream.read() # type: ignore[union-attr] + == old_contents + ) def test_get_changed_files_multiple_merge_bases(git_repo: git.Repo) -> None: @@ -727,7 +733,7 @@ def write_file(filename: str, contents: str): git_repo.index.commit("Initial commit") branch_1 = git_repo.create_head("branch-1", "master") - git_repo.head.reference = branch_1 # type: ignore + git_repo.head.reference = branch_1 # type: ignore[misc] git_repo.index.reset(index=True, working_tree=True) write_file("file1.txt", "File 1 modified\n") git_repo.index.add("file1.txt") @@ -737,7 +743,7 @@ def write_file(filename: str, contents: str): ) branch_2 = git_repo.create_head("branch-2", "master") - git_repo.head.reference = branch_2 # type: ignore + git_repo.head.reference = branch_2 # type: ignore[misc] git_repo.index.reset(index=True, working_tree=True) write_file("file2.txt", "File 2 modified\n") git_repo.index.add("file2.txt") @@ -747,7 +753,7 @@ def write_file(filename: str, contents: str): ) branch_1_2 = git_repo.create_head("branch-1-2", "master") - git_repo.head.reference = branch_1_2 # type: ignore + git_repo.head.reference = branch_1_2 # type: ignore[misc] git_repo.index.reset(index=True, working_tree=True) write_file("file1.txt", "File 1 modified\n") write_file("file2.txt", "File 2 modified\n") @@ -759,7 +765,7 @@ def write_file(filename: str, contents: str): ) branch_3 = git_repo.create_head("branch-3", "master") - git_repo.head.reference = branch_3 # type: ignore + git_repo.head.reference = branch_3 # type: ignore[misc] git_repo.index.reset(index=True, working_tree=True) write_file("file1.txt", "File 1 modified\n") write_file("file2.txt", "File 2 modified\n") @@ -873,21 +879,21 @@ def file_contents_modified(num: int) -> str: git_repo.index.commit("Initial commit") branch_1 = git_repo.create_head("branch-1", "master") - git_repo.head.reference = branch_1 # type: ignore + git_repo.head.reference = branch_1 # type: ignore[misc] git_repo.head.reset(index=True, working_tree=True) write_file("file1.txt", file_contents_modified(1)) git_repo.index.add(["file1.txt"]) git_repo.index.commit("Update file1.txt") branch_2 = git_repo.create_head("branch-2", "master") - git_repo.head.reference = branch_2 # type: ignore + git_repo.head.reference = branch_2 # type: ignore[misc] git_repo.head.reset(index=True, working_tree=True) write_file("file2.txt", file_contents_modified(2)) git_repo.index.add(["file2.txt"]) git_repo.index.commit("Update file2.txt") pr = git_repo.create_head("pr", "branch-1") - git_repo.head.reference = pr # type: ignore + git_repo.head.reference = pr # type: ignore[misc] git_repo.head.reset(index=True, working_tree=True) write_file("file3.txt", file_contents_modified(3)) git_repo.index.add(["file3.txt"]) diff --git a/test/test_pre_commit.py b/test/test_pre_commit.py index 2fa450f..f58dddc 100644 --- a/test/test_pre_commit.py +++ b/test/test_pre_commit.py @@ -86,7 +86,7 @@ def list_files(top: str) -> Generator[str, None, None]: branch_dir = os.path.join(example_dir, "branch") if os.path.exists(branch_dir): - git_repo.head.reference = git_repo.create_head( # type: ignore + git_repo.head.reference = git_repo.create_head( # type: ignore[misc] "branch", git_repo.head.commit ) git_repo.index.remove(list(list_files(master_dir)), working_tree=True) From f10af82285ca7676f2a7eca38e2fc9a38f33a2e5 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 26 Aug 2024 13:57:34 -0400 Subject: [PATCH 15/20] Add quote to third-party type hints --- src/rapids_pre_commit_hooks/alpha_spec.py | 36 +++++++++---------- src/rapids_pre_commit_hooks/copyright.py | 10 +++--- src/rapids_pre_commit_hooks/lint.py | 13 ++++--- .../pyproject_license.py | 2 +- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/src/rapids_pre_commit_hooks/alpha_spec.py b/src/rapids_pre_commit_hooks/alpha_spec.py index c48d474..9d1e7d0 100644 --- a/src/rapids_pre_commit_hooks/alpha_spec.py +++ b/src/rapids_pre_commit_hooks/alpha_spec.py @@ -36,15 +36,15 @@ @cache -def all_metadata() -> RAPIDSMetadata: +def all_metadata() -> "RAPIDSMetadata": return fetch_latest() -def node_has_type(node: yaml.Node, tag_type: str) -> bool: +def node_has_type(node: "yaml.Node", tag_type: str) -> bool: return node.tag == f"tag:yaml.org,2002:{tag_type}" -def get_rapids_version(args: argparse.Namespace) -> RAPIDSVersion: +def get_rapids_version(args: argparse.Namespace) -> "RAPIDSVersion": md = all_metadata() return ( md.versions[args.rapids_version] @@ -62,7 +62,7 @@ def strip_cuda_suffix(args: argparse.Namespace, name: str) -> str: def check_and_mark_anchor( - anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node + anchors: dict[str, "yaml.Node"], used_anchors: set[str], node: "yaml.Node" ) -> tuple[bool, Optional[str]]: for key, value in anchors.items(): if value == node: @@ -80,9 +80,9 @@ def check_and_mark_anchor( def check_package_spec( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: @total_ordering class SpecPriority: @@ -154,9 +154,9 @@ def create_specifier_string(specifiers: set[str]) -> str: def check_packages( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "seq"): descend, _ = check_and_mark_anchor(anchors, used_anchors, node) @@ -168,9 +168,9 @@ def check_packages( def check_common( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "seq"): for dependency_set in node.value: @@ -188,9 +188,9 @@ def check_common( def check_matrices( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "seq"): for item in node.value: @@ -208,9 +208,9 @@ def check_matrices( def check_specific( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "seq"): for matrix_matcher in node.value: @@ -228,9 +228,9 @@ def check_specific( def check_dependencies( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "map"): for _, dependencies_value in node.value: @@ -250,9 +250,9 @@ def check_dependencies( def check_root( linter: Linter, args: argparse.Namespace, - anchors: dict[str, yaml.Node], + anchors: dict[str, "yaml.Node"], used_anchors: set[str], - node: yaml.Node, + node: "yaml.Node", ) -> None: if node_has_type(node, "map"): for root_key, root_value in node.value: diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 8f236c4..e282869 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -108,7 +108,7 @@ def apply_copyright_check(linter: Linter, old_content: Optional[str]) -> None: linter.add_warning((0, 0), "no copyright notice found") -def get_target_branch(repo: git.Repo, args: argparse.Namespace) -> Optional[str]: +def get_target_branch(repo: "git.Repo", args: argparse.Namespace) -> Optional[str]: """Determine which branch is the "target" branch. The target branch is determined in the following order: @@ -175,7 +175,7 @@ def get_target_branch(repo: git.Repo, args: argparse.Namespace) -> Optional[str] def get_target_branch_upstream_commit( - repo: git.Repo, args: argparse.Namespace + repo: "git.Repo", args: argparse.Namespace ) -> Optional[git.Commit]: # If no target branch can be determined, use HEAD if it exists target_branch_name = get_target_branch(repo, args) @@ -232,7 +232,7 @@ def try_get_ref(remote: git.Remote) -> Optional[git.Reference]: def get_changed_files( args: argparse.Namespace, -) -> dict[Union[str, os.PathLike[str]], Optional[git.Blob]]: +) -> dict[Union[str, os.PathLike[str]], Optional["git.Blob"]]: try: repo = git.Repo() except git.InvalidGitRepositoryError: @@ -276,8 +276,8 @@ def normalize_git_filename(filename: Union[str, os.PathLike[str]]) -> Optional[s def find_blob( - tree: git.Tree, filename: Union[str, os.PathLike[str]] -) -> Optional[git.Blob]: + tree: "git.Tree", filename: Union[str, os.PathLike[str]] +) -> Optional["git.Blob"]: d1, d2 = os.path.split(filename) split = [d2] while d1: diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index ebf0110..670a6a7 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -93,7 +93,7 @@ def __init__(self, filename: str, content: str) -> None: self.filename: str = filename self.content: str = content self.warnings: list[LintWarning] = [] - self.console: Console = Console(highlight=False) + self.console: "Console" = Console(highlight=False) self._calculate_lines() def add_warning(self, pos: _PosType, msg: str) -> LintWarning: @@ -207,16 +207,19 @@ def print_highlighted_code( def line_for_pos(self, index: int) -> int: @functools.total_ordering class LineComparator: - def __init__(self, pos: _PosType): + def __init__(self, pos: _PosType) -> None: self.pos: _PosType = pos - def __lt__(self, other): + def __lt__(self, other: object) -> bool: + assert isinstance(other, LineComparator) return self.pos[1] < other - def __gt__(self, other): + def __gt__(self, other: object) -> bool: + assert isinstance(other, LineComparator) return self.pos[0] > other - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + assert isinstance(other, LineComparator) return self.pos[0] <= other <= self.pos[1] line_index = bisect.bisect_left( diff --git a/src/rapids_pre_commit_hooks/pyproject_license.py b/src/rapids_pre_commit_hooks/pyproject_license.py index 602816a..cd16723 100644 --- a/src/rapids_pre_commit_hooks/pyproject_license.py +++ b/src/rapids_pre_commit_hooks/pyproject_license.py @@ -32,7 +32,7 @@ def find_value_location( - document: tomlkit.TOMLDocument, key: tuple[str, ...], append: bool + document: "tomlkit.TOMLDocument", key: tuple[str, ...], append: bool ) -> _LocType: copied_document = copy.deepcopy(document) placeholder = uuid.uuid4() From 71878889919dbd0f60ec40da745135f1c52d0f35 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 26 Aug 2024 14:07:42 -0400 Subject: [PATCH 16/20] Fix type assertion --- src/rapids_pre_commit_hooks/lint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index 670a6a7..d10cbff 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -211,15 +211,15 @@ def __init__(self, pos: _PosType) -> None: self.pos: _PosType = pos def __lt__(self, other: object) -> bool: - assert isinstance(other, LineComparator) + assert isinstance(other, int) return self.pos[1] < other def __gt__(self, other: object) -> bool: - assert isinstance(other, LineComparator) + assert isinstance(other, int) return self.pos[0] > other def __eq__(self, other: object) -> bool: - assert isinstance(other, LineComparator) + assert isinstance(other, int) return self.pos[0] <= other <= self.pos[1] line_index = bisect.bisect_left( From 293a9a87de3a219b9388cfb514d6d7290ab7388b Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 26 Aug 2024 14:18:18 -0400 Subject: [PATCH 17/20] Remove type hints from tests --- .../test_alpha_spec.py | 64 ++++----- .../rapids_pre_commit_hooks/test_copyright.py | 128 +++++++----------- test/rapids_pre_commit_hooks/test_lint.py | 62 ++++----- .../test_pyproject_license.py | 16 +-- test/rapids_pre_commit_hooks/test_shell.py | 4 +- 5 files changed, 118 insertions(+), 156 deletions(-) diff --git a/test/rapids_pre_commit_hooks/test_alpha_spec.py b/test/rapids_pre_commit_hooks/test_alpha_spec.py index 9412004..ce61ff8 100644 --- a/test/rapids_pre_commit_hooks/test_alpha_spec.py +++ b/test/rapids_pre_commit_hooks/test_alpha_spec.py @@ -16,7 +16,6 @@ import os.path from itertools import chain from textwrap import dedent -from typing import Iterator, Optional from unittest.mock import MagicMock, Mock, call, patch import pytest @@ -32,7 +31,7 @@ @contextlib.contextmanager -def set_cwd(cwd: os.PathLike[str]) -> Iterator: +def set_cwd(cwd): old_cwd = os.getcwd() os.chdir(cwd) try: @@ -53,12 +52,8 @@ def set_cwd(cwd: os.PathLike[str]) -> Iterator: ], ) def test_get_rapids_version( - tmp_path: os.PathLike, - version_file: Optional[str], - version_arg: Optional[str], - expected_version: Optional[str], - raises: contextlib.AbstractContextManager, -) -> None: + tmp_path, version_file, version_arg, expected_version, raises +): MOCK_METADATA = RAPIDSMetadata( versions={ "24.06": RAPIDSVersion( @@ -87,11 +82,10 @@ def test_get_rapids_version( assert version == MOCK_METADATA.versions[expected_version] -def test_anchor_preserving_loader() -> None: +def test_anchor_preserving_loader(): loader = alpha_spec.AnchorPreservingLoader("- &a A\n- *a") try: root = loader.get_single_node() - assert root is not None finally: loader.dispose() assert loader.document_anchors == [{"a": root.value[0]}] @@ -129,7 +123,7 @@ def test_anchor_preserving_loader() -> None: "rapids_pre_commit_hooks.alpha_spec.get_rapids_version", Mock(return_value=latest_metadata), ) -def test_strip_cuda_suffix(name: str, stripped_name: str) -> None: +def test_strip_cuda_suffix(name, stripped_name): assert alpha_spec.strip_cuda_suffix(Mock(), name) == stripped_name @@ -174,14 +168,14 @@ def test_strip_cuda_suffix(name: str, stripped_name: str) -> None: ], ) def test_check_and_mark_anchor( - used_anchors_before: set[str], - node_index: int, - descend: bool, - anchor: Optional[str], - used_anchors_after: set[str], -) -> None: + used_anchors_before, + node_index, + descend, + anchor, + used_anchors_after, +): NODES = [Mock() for _ in range(3)] - ANCHORS: dict[str, yaml.Node] = { + ANCHORS = { "anchor1": NODES[0], "anchor2": NODES[1], } @@ -252,15 +246,12 @@ def test_check_and_mark_anchor( "rapids_pre_commit_hooks.alpha_spec.get_rapids_version", Mock(return_value=latest_metadata), ) -def test_check_package_spec( - package: str, content: str, mode: str, replacement: str -) -> None: +def test_check_package_spec(package, content, mode, replacement): args = Mock(mode=mode) linter = lint.Linter("dependencies.yaml", content) loader = alpha_spec.AnchorPreservingLoader(content) try: composed = loader.get_single_node() - assert composed is not None finally: loader.dispose() alpha_spec.check_package_spec( @@ -284,7 +275,7 @@ def test_check_package_spec( "rapids_pre_commit_hooks.alpha_spec.get_rapids_version", Mock(return_value=latest_metadata), ) -def test_check_package_spec_anchor() -> None: +def test_check_package_spec_anchor(): CONTENT = dedent( """\ - &cudf cudf>=24.04,<24.06 @@ -298,10 +289,9 @@ def test_check_package_spec_anchor() -> None: loader = alpha_spec.AnchorPreservingLoader(CONTENT) try: composed = loader.get_single_node() - assert composed is not None finally: loader.dispose() - used_anchors: set[str] = set() + used_anchors = set() expected_linter = lint.Linter("dependencies.yaml", CONTENT) expected_linter.add_warning( @@ -359,7 +349,7 @@ def test_check_package_spec_anchor() -> None: ), ], ) -def test_check_packages(content: str, indices: list[int], use_anchor: bool) -> None: +def test_check_packages(content, indices, use_anchor): with patch( "rapids_pre_commit_hooks.alpha_spec.check_package_spec", Mock() ) as mock_check_package_spec: @@ -367,7 +357,7 @@ def test_check_packages(content: str, indices: list[int], use_anchor: bool) -> N linter = lint.Linter("dependencies.yaml", content) composed = yaml.compose(content) anchors = {"anchor": composed} - used_anchors: set[str] = set() + used_anchors = set() alpha_spec.check_packages(linter, args, anchors, used_anchors, composed) assert used_anchors == ({"anchor"} if use_anchor else set()) alpha_spec.check_packages(linter, args, anchors, used_anchors, composed) @@ -397,7 +387,7 @@ def test_check_packages(content: str, indices: list[int], use_anchor: bool) -> N ), ], ) -def test_check_common(content: str, indices: list[tuple[int, int]]) -> None: +def test_check_common(content, indices): with patch( "rapids_pre_commit_hooks.alpha_spec.check_packages", Mock() ) as mock_check_packages: @@ -432,7 +422,7 @@ def test_check_common(content: str, indices: list[tuple[int, int]]) -> None: ), ], ) -def test_check_matrices(content: str, indices: list[tuple[int, int]]) -> None: +def test_check_matrices(content, indices): with patch( "rapids_pre_commit_hooks.alpha_spec.check_packages", Mock() ) as mock_check_packages: @@ -478,7 +468,7 @@ def test_check_matrices(content: str, indices: list[tuple[int, int]]) -> None: ), ], ) -def test_check_specific(content: str, indices: list[tuple[int, int]]) -> None: +def test_check_specific(content, indices): with patch( "rapids_pre_commit_hooks.alpha_spec.check_matrices", Mock() ) as mock_check_matrices: @@ -532,10 +522,10 @@ def test_check_specific(content: str, indices: list[tuple[int, int]]) -> None: ], ) def test_check_dependencies( - content: str, - common_indices: list[tuple[int, int]], - specific_indices: list[tuple[int, int]], -) -> None: + content, + common_indices, + specific_indices, +): with patch( "rapids_pre_commit_hooks.alpha_spec.check_common", Mock() ) as mock_check_common, patch( @@ -572,7 +562,7 @@ def test_check_dependencies( ), ], ) -def test_check_root(content: str, indices: list[int]) -> None: +def test_check_root(content, indices): with patch( "rapids_pre_commit_hooks.alpha_spec.check_dependencies", Mock() ) as mock_check_dependencies: @@ -587,7 +577,7 @@ def test_check_root(content: str, indices: list[int]) -> None: ] -def test_check_alpha_spec() -> None: +def test_check_alpha_spec(): CONTENT = "dependencies: []" with patch( "rapids_pre_commit_hooks.alpha_spec.check_root", Mock() @@ -607,7 +597,7 @@ def test_check_alpha_spec() -> None: ) -def test_check_alpha_spec_integration(tmp_path: os.PathLike[str]) -> None: +def test_check_alpha_spec_integration(tmp_path): CONTENT = dedent( """\ dependencies: diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index ea920f1..340c9cc 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse import datetime import os.path import tempfile -from io import BufferedReader from textwrap import dedent -from typing import Any, Optional, TextIO, Union from unittest.mock import Mock, patch import git @@ -29,7 +26,7 @@ from rapids_pre_commit_hooks.lint import Linter -def test_match_copyright() -> None: +def test_match_copyright(): CONTENT = dedent( r""" Copyright (c) 2024 NVIDIA CORPORATION @@ -70,7 +67,7 @@ def test_match_copyright() -> None: ] -def test_strip_copyright() -> None: +def test_strip_copyright(): CONTENT = dedent( r""" This is a line before the first copyright statement @@ -96,10 +93,8 @@ def test_strip_copyright() -> None: @freeze_time("2024-01-18") -def test_apply_copyright_check() -> None: - def run_apply_copyright_check( - old_content: Optional[str], new_content: str - ) -> Linter: +def test_apply_copyright_check(): + def run_apply_copyright_check(old_content, new_content): linter = Linter("file.txt", new_content) copyright.apply_copyright_check(linter, old_content) return linter @@ -178,7 +173,7 @@ def run_apply_copyright_check( @pytest.fixture -def git_repo(tmp_path: os.PathLike[str]) -> git.Repo: +def git_repo(tmp_path): repo = git.Repo.init(tmp_path) with repo.config_writer() as w: w.set_value("user", "name", "RAPIDS Test Fixtures") @@ -186,9 +181,7 @@ def git_repo(tmp_path: os.PathLike[str]) -> git.Repo: return repo -def test_get_target_branch(git_repo: git.Repo) -> None: - assert git_repo.working_tree_dir is not None - +def test_get_target_branch(git_repo): with patch.dict("os.environ", {}, clear=True): args = Mock(main_branch=None, target_branch=None) @@ -271,16 +264,15 @@ def test_get_target_branch(git_repo: git.Repo) -> None: assert copyright.get_target_branch(git_repo, args) == "master" -def test_get_target_branch_upstream_commit(git_repo: git.Repo) -> None: - def fn(repo: git.Repo, filename: str) -> str: - assert repo.working_tree_dir is not None +def test_get_target_branch_upstream_commit(git_repo): + def fn(repo, filename): return os.path.join(repo.working_tree_dir, filename) - def write_file(repo: git.Repo, filename: str, contents: str): + def write_file(repo, filename, contents): with open(fn(repo, filename), "w") as f: f.write(contents) - def mock_target_branch(branch: Any): + def mock_target_branch(branch): return patch( "rapids_pre_commit_hooks.copyright.get_target_branch", Mock(return_value=branch), @@ -318,7 +310,7 @@ def mock_target_branch(branch: Any): remote_1_branch_1 = remote_repo_1.create_head( "branch-1-renamed", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_1 # type: ignore[misc] + remote_repo_1.head.reference = remote_1_branch_1 remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file1.txt", "File 1 modified") remote_repo_1.index.add(["file1.txt"]) @@ -330,7 +322,7 @@ def mock_target_branch(branch: Any): remote_1_branch_2 = remote_repo_1.create_head( "branch-2", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_2 # type: ignore[misc] + remote_repo_1.head.reference = remote_1_branch_2 remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file2.txt", "File 2 modified") remote_repo_1.index.add(["file2.txt"]) @@ -339,7 +331,7 @@ def mock_target_branch(branch: Any): remote_1_branch_3 = remote_repo_1.create_head( "branch-3", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_3 # type: ignore[misc] + remote_repo_1.head.reference = remote_1_branch_3 remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file3.txt", "File 3 modified") remote_repo_1.index.add(["file3.txt"]) @@ -351,7 +343,7 @@ def mock_target_branch(branch: Any): remote_1_branch_4 = remote_repo_1.create_head( "branch-4", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_4 # type: ignore[misc] + remote_repo_1.head.reference = remote_1_branch_4 remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file4.txt", "File 4 modified") remote_repo_1.index.add(["file4.txt"]) @@ -363,7 +355,7 @@ def mock_target_branch(branch: Any): remote_1_branch_7 = remote_repo_1.create_head( "branch-7", remote_1_master.commit ) - remote_repo_1.head.reference = remote_1_branch_7 # type: ignore[misc] + remote_repo_1.head.reference = remote_1_branch_7 remote_repo_1.head.reset(index=True, working_tree=True) write_file(remote_repo_1, "file7.txt", "File 7 modified") remote_repo_1.index.add(["file7.txt"]) @@ -379,7 +371,7 @@ def mock_target_branch(branch: Any): remote_2_branch_3 = remote_repo_2.create_head( "branch-3", remote_2_master.commit ) - remote_repo_2.head.reference = remote_2_branch_3 # type: ignore[misc] + remote_repo_2.head.reference = remote_2_branch_3 remote_repo_2.head.reset(index=True, working_tree=True) write_file(remote_repo_2, "file3.txt", "File 3 modified") remote_repo_2.index.add(["file3.txt"]) @@ -391,7 +383,7 @@ def mock_target_branch(branch: Any): remote_2_branch_4 = remote_repo_2.create_head( "branch-4", remote_2_master.commit ) - remote_repo_2.head.reference = remote_2_branch_4 # type: ignore[misc] + remote_repo_2.head.reference = remote_2_branch_4 remote_repo_2.head.reset(index=True, working_tree=True) write_file(remote_repo_2, "file4.txt", "File 4 modified") remote_repo_2.index.add(["file4.txt"]) @@ -403,7 +395,7 @@ def mock_target_branch(branch: Any): remote_2_branch_5 = remote_repo_2.create_head( "branch-5", remote_2_master.commit ) - remote_repo_2.head.reference = remote_2_branch_5 # type: ignore[misc] + remote_repo_2.head.reference = remote_2_branch_5 remote_repo_2.head.reset(index=True, working_tree=True) write_file(remote_repo_2, "file5.txt", "File 5 modified") remote_repo_2.index.add(["file5.txt"]) @@ -433,7 +425,7 @@ def mock_target_branch(branch: Any): with branch_1.config_writer() as w: w.set_value("remote", "unconventional/remote/name/1") w.set_value("merge", "branch-1-renamed") - git_repo.head.reference = branch_1 # type: ignore[misc] + git_repo.head.reference = branch_1 git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove("file1.txt", working_tree=True) git_repo.index.commit( @@ -442,7 +434,7 @@ def mock_target_branch(branch: Any): ) branch_6 = git_repo.create_head("branch-6", remote_1.refs["master"]) - git_repo.head.reference = branch_6 # type: ignore[misc] + git_repo.head.reference = branch_6 git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove(["file6.txt"], working_tree=True) git_repo.index.commit("Remove file6.txt") @@ -451,7 +443,7 @@ def mock_target_branch(branch: Any): with branch_7.config_writer() as w: w.set_value("remote", "unconventional/remote/name/1") w.set_value("merge", "branch-7") - git_repo.head.reference = branch_7 # type: ignore[misc] + git_repo.head.reference = branch_7 git_repo.head.reset(index=True, working_tree=True) git_repo.index.remove(["file7.txt"], working_tree=True) git_repo.index.commit( @@ -459,7 +451,7 @@ def mock_target_branch(branch: Any): commit_date=datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc), ) - git_repo.head.reference = main # type: ignore[misc] + git_repo.head.reference = main git_repo.head.reset(index=True, working_tree=True) with mock_target_branch("branch-1"): @@ -517,12 +509,8 @@ def mock_target_branch(branch: Any): ) -def test_get_changed_files(git_repo: git.Repo) -> None: - f: Union[BufferedReader, TextIO] - - assert git_repo.working_tree_dir is not None - - def mock_os_walk(top: Union[str, os.PathLike[str]]): +def test_get_changed_files(git_repo): + def mock_os_walk(top): return patch( "os.walk", Mock( @@ -553,15 +541,14 @@ def mock_os_walk(top: Union[str, os.PathLike[str]]): "subdir1/subdir2/sub.txt": None, } - def fn(filename: str) -> str: - assert git_repo.working_tree_dir is not None + def fn(filename): return os.path.join(git_repo.working_tree_dir, filename) - def write_file(filename: str, contents: str): + def write_file(filename, contents): with open(fn(filename), "w") as f: f.write(contents) - def file_contents(verbed: str) -> str: + def file_contents(verbed): return f"This file will be {verbed}\n" * 100 write_file("untouched.txt", file_contents("untouched")) @@ -614,7 +601,7 @@ def file_contents(verbed: str) -> str: git_repo.index.commit("Remove modified.txt") pr_branch = git_repo.create_head("pr", "HEAD~") - git_repo.head.reference = pr_branch # type: ignore[misc] + git_repo.head.reference = pr_branch git_repo.head.reset(index=True, working_tree=True) write_file("copied_2.txt", file_contents("copied")) @@ -662,7 +649,7 @@ def file_contents(verbed: str) -> str: target_branch = git_repo.heads["master"] merge_base = git_repo.merge_base(target_branch, "HEAD")[0] old_files = { - blob.path: blob # type: ignore[union-attr] + blob.path: blob for blob in merge_base.tree.traverse(lambda b, _: isinstance(b, git.Blob)) } @@ -698,31 +685,24 @@ def file_contents(verbed: str) -> str: if old: with open(fn(new), "rb") as f: new_contents = f.read() - old_contents = old_files[old].data_stream.read() # type: ignore[union-attr] + old_contents = old_files[old].data_stream.read() assert new_contents != old_contents - assert ( - changed_files[new].data_stream.read() # type: ignore[union-attr] - == old_contents - ) + assert changed_files[new].data_stream.read() == old_contents for new, old in superfluous.items(): if old: with open(fn(new), "rb") as f: new_contents = f.read() - old_contents = old_files[old].data_stream.read() # type: ignore[union-attr] + old_contents = old_files[old].data_stream.read() assert new_contents == old_contents - assert ( - changed_files[new].data_stream.read() # type: ignore[union-attr] - == old_contents - ) + assert changed_files[new].data_stream.read() == old_contents -def test_get_changed_files_multiple_merge_bases(git_repo: git.Repo) -> None: - def fn(filename: str) -> str: - assert git_repo.working_tree_dir is not None +def test_get_changed_files_multiple_merge_bases(git_repo): + def fn(filename): return os.path.join(git_repo.working_tree_dir, filename) - def write_file(filename: str, contents: str): + def write_file(filename, contents): with open(fn(filename), "w") as f: f.write(contents) @@ -733,7 +713,7 @@ def write_file(filename: str, contents: str): git_repo.index.commit("Initial commit") branch_1 = git_repo.create_head("branch-1", "master") - git_repo.head.reference = branch_1 # type: ignore[misc] + git_repo.head.reference = branch_1 git_repo.index.reset(index=True, working_tree=True) write_file("file1.txt", "File 1 modified\n") git_repo.index.add("file1.txt") @@ -743,7 +723,7 @@ def write_file(filename: str, contents: str): ) branch_2 = git_repo.create_head("branch-2", "master") - git_repo.head.reference = branch_2 # type: ignore[misc] + git_repo.head.reference = branch_2 git_repo.index.reset(index=True, working_tree=True) write_file("file2.txt", "File 2 modified\n") git_repo.index.add("file2.txt") @@ -753,7 +733,7 @@ def write_file(filename: str, contents: str): ) branch_1_2 = git_repo.create_head("branch-1-2", "master") - git_repo.head.reference = branch_1_2 # type: ignore[misc] + git_repo.head.reference = branch_1_2 git_repo.index.reset(index=True, working_tree=True) write_file("file1.txt", "File 1 modified\n") write_file("file2.txt", "File 2 modified\n") @@ -765,7 +745,7 @@ def write_file(filename: str, contents: str): ) branch_3 = git_repo.create_head("branch-3", "master") - git_repo.head.reference = branch_3 # type: ignore[misc] + git_repo.head.reference = branch_3 git_repo.index.reset(index=True, working_tree=True) write_file("file1.txt", "File 1 modified\n") write_file("file2.txt", "File 2 modified\n") @@ -797,7 +777,7 @@ def write_file(filename: str, contents: str): } -def test_normalize_git_filename() -> None: +def test_normalize_git_filename(): assert copyright.normalize_git_filename("file.txt") == "file.txt" assert copyright.normalize_git_filename("sub/file.txt") == "sub/file.txt" assert copyright.normalize_git_filename("sub//file.txt") == "sub/file.txt" @@ -825,9 +805,7 @@ def test_normalize_git_filename() -> None: ("nonexistent/sub.txt", False), ], ) -def test_find_blob(git_repo: git.Repo, path: str, present: bool) -> None: - assert git_repo.working_tree_dir is not None - +def test_find_blob(git_repo, path, present): with open(os.path.join(git_repo.working_tree_dir, "top.txt"), "w"): pass os.mkdir(os.path.join(git_repo.working_tree_dir, "sub1")) @@ -839,23 +817,21 @@ def test_find_blob(git_repo: git.Repo, path: str, present: bool) -> None: blob = copyright.find_blob(git_repo.head.commit.tree, path) if present: - assert blob is not None assert blob.path == path else: assert blob is None @freeze_time("2024-01-18") -def test_check_copyright(git_repo: git.Repo) -> None: - def fn(filename: str) -> str: - assert git_repo.working_tree_dir is not None +def test_check_copyright(git_repo): + def fn(filename): return os.path.join(git_repo.working_tree_dir, filename) - def write_file(filename: str, contents: str): + def write_file(filename, contents): with open(fn(filename), "w") as f: f.write(contents) - def file_contents(num: int) -> str: + def file_contents(num): return dedent( rf"""\ Copyright (c) 2021-2023 NVIDIA CORPORATION @@ -863,7 +839,7 @@ def file_contents(num: int) -> str: """ ) - def file_contents_modified(num: int) -> str: + def file_contents_modified(num): return dedent( rf"""\ Copyright (c) 2021-2023 NVIDIA CORPORATION @@ -879,21 +855,21 @@ def file_contents_modified(num: int) -> str: git_repo.index.commit("Initial commit") branch_1 = git_repo.create_head("branch-1", "master") - git_repo.head.reference = branch_1 # type: ignore[misc] + git_repo.head.reference = branch_1 git_repo.head.reset(index=True, working_tree=True) write_file("file1.txt", file_contents_modified(1)) git_repo.index.add(["file1.txt"]) git_repo.index.commit("Update file1.txt") branch_2 = git_repo.create_head("branch-2", "master") - git_repo.head.reference = branch_2 # type: ignore[misc] + git_repo.head.reference = branch_2 git_repo.head.reset(index=True, working_tree=True) write_file("file2.txt", file_contents_modified(2)) git_repo.index.add(["file2.txt"]) git_repo.index.commit("Update file2.txt") pr = git_repo.create_head("pr", "branch-1") - git_repo.head.reference = pr # type: ignore[misc] + git_repo.head.reference = pr git_repo.head.reset(index=True, working_tree=True) write_file("file3.txt", file_contents_modified(3)) git_repo.index.add(["file3.txt"]) @@ -909,8 +885,8 @@ def file_contents_modified(num: int) -> str: def mock_repo_cwd(): return patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)) - def mock_target_branch_upstream_commit(target_branch: str): - def func(repo: git.Repo, args: argparse.Namespace) -> git.Commit: + def mock_target_branch_upstream_commit(target_branch): + def func(repo, args): assert target_branch == args.target_branch return repo.heads[target_branch].commit diff --git a/test/rapids_pre_commit_hooks/test_lint.py b/test/rapids_pre_commit_hooks/test_lint.py index 53687a6..285b5c4 100644 --- a/test/rapids_pre_commit_hooks/test_lint.py +++ b/test/rapids_pre_commit_hooks/test_lint.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse import contextlib import os.path from textwrap import dedent -from typing import BinaryIO, Generator, TextIO from unittest.mock import Mock, call, patch import pytest @@ -30,12 +28,12 @@ class TestLinter: - LONG_CONTENTS: str = ( + LONG_CONTENTS = ( "line 1\nline 2\rline 3\r\nline 4\r\n\nline 6\r\n\r\nline 8\n\r\n" "line 10\r\r\nline 12\r\n\rline 14\n\nline 16\r\rline 18\n\rline 20" ) - def test_lines(self) -> None: + def test_lines(self): linter = Linter("test.txt", self.LONG_CONTENTS) assert linter.lines == [ (0, 6), @@ -108,16 +106,16 @@ def test_lines(self) -> None: ) def test_line_for_pos( self, - contents: str, - pos: int, - line: int, - raises: contextlib.AbstractContextManager, - ) -> None: + contents, + pos, + line, + raises, + ): linter = Linter("test.txt", contents) with raises: assert linter.line_for_pos(pos) == line - def test_fix(self) -> None: + def test_fix(self): linter = Linter("test.txt", "Hello world!") assert linter.fix() == "Hello world!" @@ -143,7 +141,7 @@ def test_fix(self) -> None: class TestLintMain: @pytest.fixture - def hello_world_file(self, tmp_path: str) -> Generator[TextIO, None, None]: + def hello_world_file(self, tmp_path): with open(os.path.join(tmp_path, "hello_world.txt"), "w+") as f: f.write("Hello world!") f.flush() @@ -151,7 +149,7 @@ def hello_world_file(self, tmp_path: str) -> Generator[TextIO, None, None]: yield f @pytest.fixture - def hello_file(self, tmp_path: str) -> Generator[TextIO, None, None]: + def hello_file(self, tmp_path): with open(os.path.join(tmp_path, "hello.txt"), "w+") as f: f.write("Hello!") f.flush() @@ -159,7 +157,7 @@ def hello_file(self, tmp_path: str) -> Generator[TextIO, None, None]: yield f @pytest.fixture - def binary_file(self, tmp_path: str) -> Generator[BinaryIO, None, None]: + def binary_file(self, tmp_path): with open(os.path.join(tmp_path, "binary.bin"), "wb+") as f: f.write(b"\xDE\xAD\xBE\xEF") f.flush() @@ -167,7 +165,7 @@ def binary_file(self, tmp_path: str) -> Generator[BinaryIO, None, None]: yield f @pytest.fixture - def long_file(self, tmp_path: str) -> Generator[TextIO, None, None]: + def long_file(self, tmp_path): with open(os.path.join(tmp_path, "long.txt"), "w+") as f: f.write("This is a long file\nIt has multiple lines\n") f.flush() @@ -175,7 +173,7 @@ def long_file(self, tmp_path: str) -> Generator[TextIO, None, None]: yield f @pytest.fixture - def bracket_file(self, tmp_path: str) -> Generator[TextIO, None, None]: + def bracket_file(self, tmp_path): with open(os.path.join(tmp_path, "file[with]brackets.txt"), "w+") as f: f.write("This [file] [has] [brackets]\n") f.flush() @@ -183,14 +181,14 @@ def bracket_file(self, tmp_path: str) -> Generator[TextIO, None, None]: yield f @contextlib.contextmanager - def mock_console(self) -> Generator[Mock, None, None]: + def mock_console(self): m = Mock() with patch("rich.console.Console", m), patch( "rapids_pre_commit_hooks.lint.Console", m ): yield m - def the_check(self, linter: Linter, args: argparse.Namespace) -> None: + def the_check(self, linter, args): assert args.check_test linter.add_warning((0, 5), "say good bye instead").add_replacement( (0, 5), "Good bye" @@ -198,25 +196,25 @@ def the_check(self, linter: Linter, args: argparse.Namespace) -> None: if linter.content[5] != "!": linter.add_warning((5, 5), "use punctuation").add_replacement((5, 5), ",") - def long_file_check(self, linter: Linter, args: argparse.Namespace) -> None: + def long_file_check(self, linter, args): linter.add_warning((0, len(linter.content)), "this is a long file") - def long_fix_check(self, linter: Linter, args: argparse.Namespace) -> None: + def long_fix_check(self, linter, args): linter.add_warning((0, 19), "this is a long line").add_replacement( (0, 19), "This is a long file\nIt's even longer now" ) - def long_delete_fix_check(self, linter: Linter, args: argparse.Namespace) -> None: + def long_delete_fix_check(self, linter, args): linter.add_warning( (0, len(linter.content)), "this is a long file" ).add_replacement((0, len(linter.content)), "This is a short file now") - def bracket_check(self, linter: Linter, args: argparse.Namespace) -> None: + def bracket_check(self, linter, args): linter.add_warning((0, 28), "this [file] has brackets").add_replacement( (12, 17), "[has more]" ) - def test_no_warnings_no_fix(self, hello_world_file: TextIO) -> None: + def test_no_warnings_no_fix(self, hello_world_file): with patch( "sys.argv", ["check-test", "--check-test", hello_world_file.name] ), self.mock_console() as console: @@ -229,7 +227,7 @@ def test_no_warnings_no_fix(self, hello_world_file: TextIO) -> None: call(highlight=False), ] - def test_no_warnings_fix(self, hello_world_file: TextIO) -> None: + def test_no_warnings_fix(self, hello_world_file): with patch( "sys.argv", ["check-test", "--check-test", "--fix", hello_world_file.name] ), self.mock_console() as console: @@ -242,7 +240,7 @@ def test_no_warnings_fix(self, hello_world_file: TextIO) -> None: call(highlight=False), ] - def test_warnings_no_fix(self, hello_world_file: TextIO) -> None: + def test_warnings_no_fix(self, hello_world_file): with patch( "sys.argv", ["check-test", "--check-test", hello_world_file.name] ), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"): @@ -273,7 +271,7 @@ def test_warnings_no_fix(self, hello_world_file: TextIO) -> None: call().print(), ] - def test_warnings_fix(self, hello_world_file: TextIO) -> None: + def test_warnings_fix(self, hello_world_file): with patch( "sys.argv", ["check-test", "--check-test", "--fix", hello_world_file.name] ), self.mock_console() as console, pytest.raises(SystemExit, match=r"^1$"): @@ -304,7 +302,7 @@ def test_warnings_fix(self, hello_world_file: TextIO) -> None: call().print(), ] - def test_multiple_files(self, hello_world_file: TextIO, hello_file: TextIO) -> None: + def test_multiple_files(self, hello_world_file, hello_file): with patch( "sys.argv", [ @@ -353,7 +351,7 @@ def test_multiple_files(self, hello_world_file: TextIO, hello_file: TextIO) -> N call().print(), ] - def test_binary_file(self, binary_file: BinaryIO) -> None: + def test_binary_file(self, binary_file): mock_linter = Mock(wraps=Linter) with patch( "sys.argv", @@ -373,7 +371,7 @@ def test_binary_file(self, binary_file: BinaryIO) -> None: ctx.add_check(self.the_check) mock_linter.assert_not_called() - def test_long_file(self, long_file: TextIO) -> None: + def test_long_file(self, long_file): with patch( "sys.argv", [ @@ -411,7 +409,7 @@ def test_long_file(self, long_file: TextIO) -> None: call().print(), ] - def test_long_file_delete(self, long_file: TextIO) -> None: + def test_long_file_delete(self, long_file): with patch( "sys.argv", [ @@ -444,7 +442,7 @@ def test_long_file_delete(self, long_file: TextIO) -> None: call().print(), ] - def test_long_file_fix(self, long_file: TextIO) -> None: + def test_long_file_fix(self, long_file): with patch( "sys.argv", [ @@ -483,7 +481,7 @@ def test_long_file_fix(self, long_file: TextIO) -> None: call().print(), ] - def test_long_file_delete_fix(self, long_file: TextIO) -> None: + def test_long_file_delete_fix(self, long_file): with patch( "sys.argv", [ @@ -511,7 +509,7 @@ def test_long_file_delete_fix(self, long_file: TextIO) -> None: call().print(), ] - def test_bracket_file(self, bracket_file: TextIO) -> None: + def test_bracket_file(self, bracket_file): with patch( "sys.argv", [ diff --git a/test/rapids_pre_commit_hooks/test_pyproject_license.py b/test/rapids_pre_commit_hooks/test_pyproject_license.py index fa24ac9..195224e 100644 --- a/test/rapids_pre_commit_hooks/test_pyproject_license.py +++ b/test/rapids_pre_commit_hooks/test_pyproject_license.py @@ -21,8 +21,6 @@ from rapids_pre_commit_hooks import pyproject_license from rapids_pre_commit_hooks.lint import Linter -_LocType = tuple[int, int] - @pytest.mark.parametrize( ["key", "append", "loc"], @@ -54,7 +52,7 @@ ), ], ) -def test_find_value_location(key: tuple[str, ...], append: bool, loc: _LocType) -> None: +def test_find_value_location(key, append, loc): CONTENT = dedent( """\ [table] @@ -182,12 +180,12 @@ def test_find_value_location(key: tuple[str, ...], append: bool, loc: _LocType) ], ) def test_check_pyproject_license( - document: str, - loc: _LocType, - message: str, - replacement_loc: _LocType, - replacement_text: str, -) -> None: + document, + loc, + message, + replacement_loc, + replacement_text, +): linter = Linter("pyproject.toml", document) pyproject_license.check_pyproject_license(linter, Mock()) diff --git a/test/rapids_pre_commit_hooks/test_shell.py b/test/rapids_pre_commit_hooks/test_shell.py index 50dbec3..f5ddbf2 100644 --- a/test/rapids_pre_commit_hooks/test_shell.py +++ b/test/rapids_pre_commit_hooks/test_shell.py @@ -18,7 +18,7 @@ from rapids_pre_commit_hooks.shell.verify_conda_yes import VerifyCondaYesVisitor -def run_shell_linter(content: str, cls: type) -> Linter: +def run_shell_linter(content, cls): linter = Linter("test.sh", content) visitor = cls(linter, None) parts = bashlex.parse(content) @@ -27,7 +27,7 @@ def run_shell_linter(content: str, cls: type) -> Linter: return linter -def test_verify_conda_yes() -> None: +def test_verify_conda_yes(): CONTENT = r""" conda install -y pkg1 conda install --yes pkg2 pkg3 From 21ca4e080fd928b5d189f792543c73f8d98b4840 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 26 Aug 2024 14:20:47 -0400 Subject: [PATCH 18/20] Remove more type hints from tests --- test/test_pre_commit.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/test/test_pre_commit.py b/test/test_pre_commit.py index f58dddc..f0b651d 100644 --- a/test/test_pre_commit.py +++ b/test/test_pre_commit.py @@ -19,13 +19,11 @@ import subprocess import sys from functools import cache -from typing import Generator, Optional, Union import git import pytest import yaml from packaging.version import Version -from rapids_metadata.metadata import RAPIDSMetadata from rapids_metadata.remote import fetch_latest REPO_DIR = os.path.join(os.path.dirname(__file__), "..") @@ -35,12 +33,12 @@ @cache -def all_metadata() -> RAPIDSMetadata: +def all_metadata(): return fetch_latest() @contextlib.contextmanager -def set_cwd(cwd: Union[str, os.PathLike[str]]) -> Generator: +def set_cwd(cwd): old_cwd = os.getcwd() os.chdir(cwd) try: @@ -50,7 +48,7 @@ def set_cwd(cwd: Union[str, os.PathLike[str]]) -> Generator: @pytest.fixture -def git_repo(tmp_path: str) -> git.Repo: +def git_repo(tmp_path): repo = git.Repo.init(tmp_path) with repo.config_writer() as w: w.set_value("user", "name", "RAPIDS Test Fixtures") @@ -58,12 +56,8 @@ def git_repo(tmp_path: str) -> git.Repo: return repo -def run_pre_commit( - git_repo: git.Repo, hook_name: str, expected_status: str, exc: Optional[type] -) -> None: - assert git_repo.working_tree_dir is not None - - def list_files(top: str) -> Generator[str, None, None]: +def run_pre_commit(git_repo, hook_name, expected_status, exc): + def list_files(top): for dirpath, _, filenames in os.walk(top): for filename in filenames: yield filename if top == dirpath else os.path.join( @@ -110,7 +104,7 @@ def list_files(top: str) -> Generator[str, None, None]: "hook_name", ALL_HOOKS, ) -def test_pre_commit_pass(git_repo: git.Repo, hook_name: str) -> None: +def test_pre_commit_pass(git_repo, hook_name): run_pre_commit(git_repo, hook_name, "pass", None) @@ -118,5 +112,5 @@ def test_pre_commit_pass(git_repo: git.Repo, hook_name: str) -> None: "hook_name", ALL_HOOKS, ) -def test_pre_commit_fail(git_repo: git.Repo, hook_name: str) -> None: +def test_pre_commit_fail(git_repo, hook_name): run_pre_commit(git_repo, hook_name, "fail", subprocess.CalledProcessError) From 0635d1783b6f29249dd6b926670325032aa8860f Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 26 Aug 2024 14:21:37 -0400 Subject: [PATCH 19/20] Remove type: ignore --- test/test_pre_commit.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/test_pre_commit.py b/test/test_pre_commit.py index f0b651d..4579aa0 100644 --- a/test/test_pre_commit.py +++ b/test/test_pre_commit.py @@ -80,9 +80,7 @@ def list_files(top): branch_dir = os.path.join(example_dir, "branch") if os.path.exists(branch_dir): - git_repo.head.reference = git_repo.create_head( # type: ignore[misc] - "branch", git_repo.head.commit - ) + git_repo.head.reference = git_repo.create_head("branch", git_repo.head.commit) git_repo.index.remove(list(list_files(master_dir)), working_tree=True) shutil.copytree(branch_dir, git_repo.working_tree_dir, dirs_exist_ok=True) git_repo.index.add(list(list_files(branch_dir))) From 807ac93c0757022b54f3e87e221239f0cd82b29c Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 26 Aug 2024 14:26:43 -0400 Subject: [PATCH 20/20] More quotes --- src/rapids_pre_commit_hooks/alpha_spec.py | 2 +- src/rapids_pre_commit_hooks/copyright.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rapids_pre_commit_hooks/alpha_spec.py b/src/rapids_pre_commit_hooks/alpha_spec.py index 9d1e7d0..dd3d982 100644 --- a/src/rapids_pre_commit_hooks/alpha_spec.py +++ b/src/rapids_pre_commit_hooks/alpha_spec.py @@ -270,7 +270,7 @@ def __init__(self, stream) -> None: super().__init__(stream) self.document_anchors: list[dict[str, yaml.Node]] = [] - def compose_document(self) -> yaml.Node: + def compose_document(self) -> "yaml.Node": # Drop the DOCUMENT-START event. self.get_event() diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index e282869..9c46ad4 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -202,7 +202,7 @@ def get_target_branch_upstream_commit( key=lambda commit: commit.committed_datetime, ) - def try_get_ref(remote: git.Remote) -> Optional[git.Reference]: + def try_get_ref(remote: "git.Remote") -> Optional["git.Reference"]: try: return remote.refs[target_branch_name] except IndexError: @@ -242,7 +242,7 @@ def get_changed_files( for filename in filenames } - changed_files: dict[Union[str, os.PathLike[str]], Optional[git.Blob]] = { + changed_files: dict[Union[str, os.PathLike[str]], Optional["git.Blob"]] = { f: None for f in repo.untracked_files } target_branch_upstream_commit = get_target_branch_upstream_commit(repo, args)