Skip to content

Commit

Permalink
Pass more change information to apply_copyright_check
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleFromNVIDIA committed Aug 26, 2024
1 parent 92adb5e commit 8028f1b
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 139 deletions.
40 changes: 24 additions & 16 deletions src/rapids_pre_commit_hooks/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ def apply_copyright_update(linter: Linter, match: re.Match, year: int) -> None:
)


def apply_copyright_check(linter: Linter, old_content: Optional[str]) -> None:
def apply_copyright_check(
linter: Linter,
change_type: str,
old_filename: Optional[str],
old_content: Optional[str],
) -> None:
if linter.content != old_content:
current_year = datetime.datetime.now().year
new_copyright_matches = match_copyright(linter.content)
Expand Down Expand Up @@ -232,22 +237,24 @@ def try_get_ref(remote: "git.Remote") -> Optional["git.Reference"]:

def get_changed_files(
args: argparse.Namespace,
) -> dict[Union[str, os.PathLike[str]], Optional["git.Blob"]]:
) -> dict[Union[str, os.PathLike[str]], tuple[str, Optional["git.Blob"]]]:
try:
repo = git.Repo()
except git.InvalidGitRepositoryError:
return {
os.path.relpath(os.path.join(dirpath, filename), "."): None
os.path.relpath(os.path.join(dirpath, filename), "."): ("A", None)
for dirpath, dirnames, filenames in os.walk(".")
for filename in filenames
}

changed_files: dict[Union[str, os.PathLike[str]], Optional["git.Blob"]] = {
f: None for f in repo.untracked_files
}
changed_files: dict[
Union[str, os.PathLike[str]], tuple[str, Optional["git.Blob"]]
] = {f: ("A", None) for f in repo.untracked_files}
target_branch_upstream_commit = get_target_branch_upstream_commit(repo, args)
if target_branch_upstream_commit is None:
changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()})
changed_files.update(
{blob.path: ("A", None) for _, blob in repo.index.iter_blobs()}
)
return changed_files

for merge_base in repo.merge_base(
Expand All @@ -261,9 +268,9 @@ def get_changed_files(
)
for diff in diffs:
if diff.change_type == "A":
changed_files[diff.b_path] = None
changed_files[diff.b_path] = (diff.change_type, None)
elif diff.change_type != "D":
changed_files[diff.b_path] = diff.a_blob
changed_files[diff.b_path] = (diff.change_type, diff.a_blob)

return changed_files

Expand Down Expand Up @@ -312,16 +319,17 @@ def the_check(linter: Linter, args: argparse.Namespace):
return

try:
changed_file = changed_files[git_filename]
change_type, changed_file = changed_files[git_filename]
except KeyError:
return

old_content = (
changed_file.data_stream.read().decode()
if changed_file is not None
else None
)
apply_copyright_check(linter, old_content)
if changed_file is None:
old_filename = None
old_content = None
else:
old_filename = changed_file.name
old_content = changed_file.data_stream.read().decode()
apply_copyright_check(linter, change_type, old_filename, old_content)

return the_check

Expand Down
4 changes: 2 additions & 2 deletions src/rapids_pre_commit_hooks/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ class LintWarning:
pos: _PosType
msg: str
replacements: list[Replacement] = dataclasses.field(
default_factory=list, init=False
default_factory=list, kw_only=True
)
notes: list[Note] = dataclasses.field(default_factory=list, init=False)
notes: list[Note] = dataclasses.field(default_factory=list, kw_only=True)

def add_replacement(self, pos: _PosType, newtext: str) -> None:
self.replacements.append(Replacement(pos, newtext))
Expand Down
Loading

0 comments on commit 8028f1b

Please sign in to comment.