Add type hints to all code #50

Merged
merged 20 commits into from
Aug 26, 2024
Changes from 11 commits
Commits
20 commits
2481170
Add type hints to rapids_pre_commit_hooks.lint
KyleFromNVIDIA Aug 23, 2024
6336dca
Add type hints to rapids_pre_commit_hooks.alpha_spec
KyleFromNVIDIA Aug 23, 2024
9066273
Add type hints to rapids_pre_commit_hooks.copyright
KyleFromNVIDIA Aug 23, 2024
c471b96
Add type hints to rapids_pre_commit_hooks.pyproject_license
KyleFromNVIDIA Aug 23, 2024
f1447e5
Add type hints to rapids_pre_commit_hooks.shell
KyleFromNVIDIA Aug 23, 2024
dff9f46
Add type hints to test/rapids_pre_commit_hooks/test_alpha_spec.py
KyleFromNVIDIA Aug 23, 2024
04e8282
Add type hints to test/rapids_pre_commit_hooks/test_copyright.py
KyleFromNVIDIA Aug 23, 2024
3f96086
Add type hints to test/rapids_pre_commit_hooks/test_lint.py
KyleFromNVIDIA Aug 23, 2024
1af8f78
Add type hints to test/rapids_pre_commit_hooks/test_pyproject_license.py
KyleFromNVIDIA Aug 23, 2024
5c9aa5b
Add type hints to test/rapids_pre_commit_hooks/test_shell.py
KyleFromNVIDIA Aug 23, 2024
556322e
Add type hints to test/test_pre_commit.py
KyleFromNVIDIA Aug 23, 2024
93000b1
Add more type hints, refactor test
KyleFromNVIDIA Aug 26, 2024
8917b35
Type assert
KyleFromNVIDIA Aug 26, 2024
62f24b6
Make type: ignore more specific
KyleFromNVIDIA Aug 26, 2024
f10af82
Add quote to third-party type hints
KyleFromNVIDIA Aug 26, 2024
7187888
Fix type assertion
KyleFromNVIDIA Aug 26, 2024
293a9a8
Remove type hints from tests
KyleFromNVIDIA Aug 26, 2024
21ca4e0
Remove more type hints from tests
KyleFromNVIDIA Aug 26, 2024
0635d17
Remove type: ignore
KyleFromNVIDIA Aug 26, 2024
807ac93
More quotes
KyleFromNVIDIA Aug 26, 2024
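Several of the later commit messages above describe typing patterns rather than specific files: quoted annotations for third-party types ("Add quote to third-party type hints", "More quotes"), assert-based narrowing ("Type assert", "Fix type assertion"), and restricting `# type: ignore` comments to a specific error code ("Make type: ignore more specific"). The sketch below only illustrates what those patterns typically look like; the functions and names in it are hypothetical and are not taken from this PR.

```python
# Illustrative sketch only; these helpers are hypothetical and not part of this PR.
import re
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    import git  # third-party import needed only by the type checker


def short_sha(commit: "git.Commit") -> str:
    # Quoted third-party annotation: "git.Commit" stays a string, so GitPython
    # does not have to be importable at runtime just to satisfy this signature.
    return commit.hexsha[:7]


def find_year(text: str) -> Optional[str]:
    match = re.search(r"\d{4}", text)
    return match.group() if match else None


def require_year(text: str) -> str:
    year = find_year(text)
    # Assert-based narrowing: after the assert, the checker treats `year` as str.
    assert year is not None
    return year


untyped: object = "2024"
# Error-code-specific ignore: silences only the [assignment] diagnostic on this
# line instead of suppressing every possible error.
year_label: str = untyped  # type: ignore[assignment]
```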
100 changes: 75 additions & 25 deletions src/rapids_pre_commit_hooks/alpha_spec.py
@@ -19,27 +19,27 @@

import yaml
from packaging.requirements import InvalidRequirement, Requirement
from rapids_metadata.metadata import RAPIDSVersion
from rapids_metadata.metadata import RAPIDSMetadata, RAPIDSVersion
from rapids_metadata.remote import fetch_latest

from .lint import LintMain
from .lint import Linter, LintMain

ALPHA_SPECIFIER = ">=0.0.0a0"
ALPHA_SPECIFIER: str = ">=0.0.0a0"

ALPHA_SPEC_OUTPUT_TYPES = {
ALPHA_SPEC_OUTPUT_TYPES: set[str] = {
"pyproject",
"requirements",
}

CUDA_SUFFIX_REGEX = re.compile(r"^(?P<package>.*)-cu[0-9]{2}$")
CUDA_SUFFIX_REGEX: re.Pattern = re.compile(r"^(?P<package>.*)-cu[0-9]{2}$")


@cache
def all_metadata():
def all_metadata() -> RAPIDSMetadata:
return fetch_latest()


def node_has_type(node, tag_type):
def node_has_type(node: yaml.Node, tag_type: str):
return node.tag == f"tag:yaml.org,2002:{tag_type}"


@@ -60,7 +60,9 @@ def strip_cuda_suffix(args: argparse.Namespace, name: str) -> str:
return name


def check_and_mark_anchor(anchors, used_anchors, node):
def check_and_mark_anchor(
anchors: dict[str, yaml.Node], used_anchors: set[str], node: yaml.Node
):
for key, value in anchors.items():
if value == node:
anchor = key
@@ -74,16 +74,26 @@ def check_and_mark_anchor(anchors, used_anchors, node):
return True, anchor


def check_package_spec(linter, args, anchors, used_anchors, node):
def check_package_spec(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
used_anchors: set[str],
node: yaml.Node,
):
@total_ordering
class SpecPriority:
def __init__(self, spec):
self.spec = spec
def __init__(self, spec: str):
self.spec: str = spec

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, SpecPriority):
return False
return self.spec == other.spec

def __lt__(self, other):
def __lt__(self, other: object) -> bool:
if not isinstance(other, SpecPriority):
return False
if self.spec == other.spec:
return False
if self.spec == ALPHA_SPECIFIER:
@@ -92,10 +104,10 @@ def __lt__(self, other):
return True
return self.sort_str() < other.sort_str()

def sort_str(self):
def sort_str(self) -> str:
return "".join(c for c in self.spec if c not in "<>=")

def create_specifier_string(specifiers):
def create_specifier_string(specifiers: set[str]) -> str:
return ",".join(sorted(specifiers, key=SpecPriority))

if node_has_type(node, "str"):
@@ -140,15 +152,27 @@ def create_specifier_string(specifiers):
)


def check_packages(linter, args, anchors, used_anchors, node):
def check_packages(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
used_anchors: set[str],
node: yaml.Node,
):
if node_has_type(node, "seq"):
descend, _ = check_and_mark_anchor(anchors, used_anchors, node)
if descend:
for package_spec in node.value:
check_package_spec(linter, args, anchors, used_anchors, package_spec)


def check_common(linter, args, anchors, used_anchors, node):
def check_common(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
used_anchors: set[str],
node: yaml.Node,
):
if node_has_type(node, "seq"):
for dependency_set in node.value:
if node_has_type(dependency_set, "map"):
@@ -162,7 +186,13 @@ def check_common(linter, args, anchors, used_anchors, node):
)


def check_matrices(linter, args, anchors, used_anchors, node):
def check_matrices(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
used_anchors: set[str],
node: yaml.Node,
):
if node_has_type(node, "seq"):
for item in node.value:
if node_has_type(item, "map"):
@@ -176,7 +206,13 @@ def check_matrices(linter, args, anchors, used_anchors, node):
)


def check_specific(linter, args, anchors, used_anchors, node):
def check_specific(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
used_anchors: set[str],
node: yaml.Node,
):
if node_has_type(node, "seq"):
for matrix_matcher in node.value:
if node_has_type(matrix_matcher, "map"):
@@ -190,7 +226,13 @@ def check_specific(linter, args, anchors, used_anchors, node):
)


def check_dependencies(linter, args, anchors, used_anchors, node):
def check_dependencies(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
used_anchors: set[str],
node: yaml.Node,
):
if node_has_type(node, "map"):
for _, dependencies_value in node.value:
if node_has_type(dependencies_value, "map"):
@@ -206,7 +248,13 @@ def check_dependencies(linter, args, anchors, used_anchors, node):
)


def check_root(linter, args, anchors, used_anchors, node):
def check_root(
linter: Linter,
args: argparse.Namespace,
anchors: dict[str, yaml.Node],
used_anchors: set[str],
node: yaml.Node,
):
if node_has_type(node, "map"):
for root_key, root_value in node.value:
if node_has_type(root_key, "str") and root_key.value == "dependencies":
@@ -221,27 +269,29 @@ class AnchorPreservingLoader(yaml.SafeLoader):

def __init__(self, stream):
super().__init__(stream)
self.document_anchors = []
self.document_anchors: list[dict[str, yaml.Node]] = []

def compose_document(self):
def compose_document(self) -> yaml.Node:
# Drop the DOCUMENT-START event.
self.get_event()

# Compose the root node.
node = self.compose_node(None, None)
node = self.compose_node(None, None) # type: ignore

# Drop the DOCUMENT-END event.
self.get_event()

self.document_anchors.append(self.anchors)
self.anchors = {}
assert node is not None
return node


def check_alpha_spec(linter, args):
def check_alpha_spec(linter: Linter, args: argparse.Namespace):
loader = AnchorPreservingLoader(linter.content)
try:
root = loader.get_single_node()
assert root is not None
finally:
loader.dispose()
check_root(linter, args, loader.document_anchors[0], set(), root)
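A note on the `SpecPriority` ordering introduced above: because the class is nested inside `check_package_spec`, its effect is easiest to see in a standalone restatement. The snippet below copies the comparison logic from the diff (it is not additional code from the PR) and adds a hypothetical usage example showing that the alpha specifier always sorts to the end of the joined specifier string.

```python
# Standalone restatement of the ordering shown in the diff above (not part of
# the PR); the real module nests these inside check_package_spec().
from functools import total_ordering

ALPHA_SPECIFIER: str = ">=0.0.0a0"


@total_ordering
class SpecPriority:
    def __init__(self, spec: str):
        self.spec: str = spec

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SpecPriority):
            return False
        return self.spec == other.spec

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, SpecPriority):
            return False
        if self.spec == other.spec:
            return False
        if self.spec == ALPHA_SPECIFIER:
            return False  # the alpha specifier always sorts last
        if other.spec == ALPHA_SPECIFIER:
            return True
        return self.sort_str() < other.sort_str()

    def sort_str(self) -> str:
        # Compare specifiers by their version text, ignoring the operators.
        return "".join(c for c in self.spec if c not in "<>=")


def create_specifier_string(specifiers: set[str]) -> str:
    return ",".join(sorted(specifiers, key=SpecPriority))


# Hypothetical usage: the alpha specifier is pushed to the end of the string.
print(create_specifier_string({">=0.0.0a0", "<25.0", ">=24.8"}))
# -> ">=24.8,<25.0,>=0.0.0a0"
```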
50 changes: 31 additions & 19 deletions src/rapids_pre_commit_hooks/copyright.py
@@ -12,22 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import datetime
import functools
import os
import re
import warnings
from typing import Callable, Optional, Union

import git

from .lint import LintMain
from .lint import Linter, LintMain

COPYRIGHT_RE = re.compile(
COPYRIGHT_RE: re.Pattern = re.compile(
r"Copyright *(?:\(c\))? *(?P<years>(?P<first_year>\d{4})(-(?P<last_year>\d{4}))?),?"
r" *NVIDIA C(?:ORPORATION|orporation)"
)
BRANCH_RE = re.compile(r"^branch-(?P<major>[0-9]+)\.(?P<minor>[0-9]+)$")
COPYRIGHT_REPLACEMENT = "Copyright (c) {first_year}-{last_year}, NVIDIA CORPORATION"
BRANCH_RE: re.Pattern = re.compile(r"^branch-(?P<major>[0-9]+)\.(?P<minor>[0-9]+)$")
COPYRIGHT_REPLACEMENT: str = (
"Copyright (c) {first_year}-{last_year}, NVIDIA CORPORATION"
)


class NoTargetBranchWarning(RuntimeWarning):
@@ -38,14 +42,14 @@ class ConflictingFilesWarning(RuntimeWarning):
pass


def match_copyright(content):
def match_copyright(content: str) -> list[re.Match]:
return list(COPYRIGHT_RE.finditer(content))


def strip_copyright(content, copyright_matches):
def strip_copyright(content: str, copyright_matches: list[re.Match]) -> list[str]:
lines = []

def append_stripped(start, item):
def append_stripped(start: int, item: re.Match):
lines.append(content[start : item.start()])
return item.end()

@@ -54,7 +58,7 @@ def append_stripped(start, item):
return lines


def apply_copyright_revert(linter, old_match, new_match):
def apply_copyright_revert(linter: Linter, old_match: re.Match, new_match: re.Match):
if old_match.group("years") == new_match.group("years"):
warning_pos = new_match.span()
else:
@@ -65,7 +69,7 @@ def apply_copyright_update(linter, match, year):
).add_replacement(new_match.span(), old_match.group())


def apply_copyright_update(linter, match, year):
def apply_copyright_update(linter: Linter, match: re.Match, year: int):
linter.add_warning(match.span("years"), "copyright is out of date").add_replacement(
match.span(),
COPYRIGHT_REPLACEMENT.format(
@@ -75,7 +79,7 @@ def apply_copyright_check(linter, old_content):
)


def apply_copyright_check(linter, old_content):
def apply_copyright_check(linter: Linter, old_content: str):
if linter.content != old_content:
current_year = datetime.datetime.now().year
new_copyright_matches = match_copyright(linter.content)
@@ -102,7 +106,7 @@ def get_target_branch(repo, args):
linter.add_warning((0, 0), "no copyright notice found")


def get_target_branch(repo, args):
def get_target_branch(repo: git.Repo, args: argparse.Namespace) -> Optional[str]:
"""Determine which branch is the "target" branch.

The target branch is determined in the following order:
@@ -168,7 +172,9 @@ def get_target_branch_upstream_commit(repo, args):
return None


def get_target_branch_upstream_commit(repo, args):
def get_target_branch_upstream_commit(
repo: git.Repo, args: argparse.Namespace
) -> Optional[git.Commit]:
# If no target branch can be determined, use HEAD if it exists
target_branch_name = get_target_branch(repo, args)
if target_branch_name is None:
Expand All @@ -194,7 +200,7 @@ def get_target_branch_upstream_commit(repo, args):
key=lambda commit: commit.committed_datetime,
)

def try_get_ref(remote):
def try_get_ref(remote: git.Remote) -> Optional[git.Reference]:
try:
return remote.refs[target_branch_name]
except IndexError:
@@ -222,7 +228,9 @@ def try_get_ref(remote):
return None


def get_changed_files(args):
def get_changed_files(
args: argparse.Namespace,
) -> dict[Union[str, os.PathLike[str]], Optional[git.Blob]]:
try:
repo = git.Repo()
except git.InvalidGitRepositoryError:
@@ -232,7 +240,9 @@ def get_changed_files(args):
for filename in filenames
}

changed_files = {f: None for f in repo.untracked_files}
changed_files: dict[Union[str, os.PathLike[str]], Optional[git.Blob]] = {
f: None for f in repo.untracked_files
}
target_branch_upstream_commit = get_target_branch_upstream_commit(repo, args)
if target_branch_upstream_commit is None:
changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()})
@@ -256,14 +266,14 @@ def normalize_git_filename(filename):
return changed_files


def normalize_git_filename(filename):
def normalize_git_filename(filename: Union[str, os.PathLike[str]]):
relpath = os.path.relpath(filename)
if re.search(r"^\.\.(/|$)", relpath):
return None
return relpath


def find_blob(tree, filename):
def find_blob(tree: git.Tree, filename: Union[str, os.PathLike[str]]):
d1, d2 = os.path.split(filename)
split = [d2]
while d1:
@@ -283,10 +293,12 @@ def find_blob(tree, filename):
return None


def check_copyright(args):
def check_copyright(
args: argparse.Namespace,
) -> Callable[[Linter, argparse.Namespace], None]:
changed_files = get_changed_files(args)

def the_check(linter, args):
def the_check(linter: Linter, args: argparse.Namespace):
if not (git_filename := normalize_git_filename(linter.filename)):
warnings.warn(
f'File "{linter.filename}" is outside of current directory. Not '
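For context on the copyright-update path above: `apply_copyright_update` fills `COPYRIGHT_REPLACEMENT` with the first matched year and the target year. The snippet below reuses the regex and template shown in the diff to make the resulting text concrete; the `header` string and the direct `format` call are illustrative, not the hook's actual call path (which goes through `Linter` warnings and replacements).

```python
# Sketch of how the regex and replacement template above interact.
# The input line is hypothetical; the constants are copied from the diff.
import re

COPYRIGHT_RE: re.Pattern = re.compile(
    r"Copyright *(?:\(c\))? *(?P<years>(?P<first_year>\d{4})(-(?P<last_year>\d{4}))?),?"
    r" *NVIDIA C(?:ORPORATION|orporation)"
)
COPYRIGHT_REPLACEMENT: str = (
    "Copyright (c) {first_year}-{last_year}, NVIDIA CORPORATION"
)

header = "# Copyright (c) 2023, NVIDIA CORPORATION."
match = COPYRIGHT_RE.search(header)
assert match is not None  # the same assert-based narrowing used throughout the PR

# Keep the original first year and extend the range to the current year.
updated = COPYRIGHT_REPLACEMENT.format(
    first_year=match.group("first_year"),
    last_year=2024,
)
print(updated)  # Copyright (c) 2023-2024, NVIDIA CORPORATION
```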