From 8028f1b3ec4faff26e88133a69bf11e09b5f5d5f Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 26 Aug 2024 15:26:00 -0400 Subject: [PATCH] Pass more change information to apply_copyright_check --- src/rapids_pre_commit_hooks/copyright.py | 40 +- src/rapids_pre_commit_hooks/lint.py | 4 +- .../rapids_pre_commit_hooks/test_copyright.py | 362 ++++++++++++------ 3 files changed, 267 insertions(+), 139 deletions(-) diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 9c46ad4..d02c600 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -81,7 +81,12 @@ def apply_copyright_update(linter: Linter, match: re.Match, year: int) -> None: ) -def apply_copyright_check(linter: Linter, old_content: Optional[str]) -> None: +def apply_copyright_check( + linter: Linter, + change_type: str, + old_filename: Optional[str], + old_content: Optional[str], +) -> None: if linter.content != old_content: current_year = datetime.datetime.now().year new_copyright_matches = match_copyright(linter.content) @@ -232,22 +237,24 @@ def try_get_ref(remote: "git.Remote") -> Optional["git.Reference"]: def get_changed_files( args: argparse.Namespace, -) -> dict[Union[str, os.PathLike[str]], Optional["git.Blob"]]: +) -> dict[Union[str, os.PathLike[str]], tuple[str, Optional["git.Blob"]]]: try: repo = git.Repo() except git.InvalidGitRepositoryError: return { - os.path.relpath(os.path.join(dirpath, filename), "."): None + os.path.relpath(os.path.join(dirpath, filename), "."): ("A", None) for dirpath, dirnames, filenames in os.walk(".") for filename in filenames } - changed_files: dict[Union[str, os.PathLike[str]], Optional["git.Blob"]] = { - f: None for f in repo.untracked_files - } + changed_files: dict[ + Union[str, os.PathLike[str]], tuple[str, Optional["git.Blob"]] + ] = {f: ("A", None) for f in repo.untracked_files} target_branch_upstream_commit = get_target_branch_upstream_commit(repo, args) if target_branch_upstream_commit is None: - changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()}) + changed_files.update( + {blob.path: ("A", None) for _, blob in repo.index.iter_blobs()} + ) return changed_files for merge_base in repo.merge_base( @@ -261,9 +268,9 @@ def get_changed_files( ) for diff in diffs: if diff.change_type == "A": - changed_files[diff.b_path] = None + changed_files[diff.b_path] = (diff.change_type, None) elif diff.change_type != "D": - changed_files[diff.b_path] = diff.a_blob + changed_files[diff.b_path] = (diff.change_type, diff.a_blob) return changed_files @@ -312,16 +319,17 @@ def the_check(linter: Linter, args: argparse.Namespace): return try: - changed_file = changed_files[git_filename] + change_type, changed_file = changed_files[git_filename] except KeyError: return - old_content = ( - changed_file.data_stream.read().decode() - if changed_file is not None - else None - ) - apply_copyright_check(linter, old_content) + if changed_file is None: + old_filename = None + old_content = None + else: + old_filename = changed_file.name + old_content = changed_file.data_stream.read().decode() + apply_copyright_check(linter, change_type, old_filename, old_content) return the_check diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index 781cecb..3a3b033 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -65,9 +65,9 @@ class LintWarning: pos: _PosType msg: str replacements: list[Replacement] = dataclasses.field( - default_factory=list, init=False + default_factory=list, kw_only=True ) - notes: list[Note] = dataclasses.field(default_factory=list, init=False) + notes: list[Note] = dataclasses.field(default_factory=list, kw_only=True) def add_replacement(self, pos: _PosType, newtext: str) -> None: self.replacements.append(Replacement(pos, newtext)) diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 340c9cc..6044163 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -23,7 +23,7 @@ from freezegun import freeze_time from rapids_pre_commit_hooks import copyright -from rapids_pre_commit_hooks.lint import Linter +from rapids_pre_commit_hooks.lint import Linter, LintWarning, Replacement def test_match_copyright(): @@ -92,84 +92,188 @@ def test_strip_copyright(): assert stripped == ["No copyright here"] +@pytest.mark.parametrize( + [ + "change_type", + "old_filename", + "old_content", + "new_filename", + "new_content", + "warnings", + ], + [ + ( + "A", + None, + None, + "file.txt", + "No copyright notice", + [ + LintWarning((0, 0), "no copyright notice found"), + ], + ), + ( + "M", + "file.txt", + "No copyright notice", + "file.txt", + "No copyright notice", + [], + ), + ( + "M", + "file.txt", + dedent( + r""" + Copyright (c) 2021-2023 NVIDIA CORPORATION + Copyright (c) 2023 NVIDIA CORPORATION + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2025 NVIDIA CORPORATION + This file has not been changed + """ + ), + "file.txt", + dedent( + r""" + Copyright (c) 2021-2023 NVIDIA CORPORATION + Copyright (c) 2023 NVIDIA CORPORATION + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2025 NVIDIA CORPORATION + This file has not been changed + """ + ), + [], + ), + ( + "M", + "file.txt", + dedent( + r""" + Copyright (c) 2021-2023 NVIDIA CORPORATION + Copyright (c) 2023 NVIDIA CORPORATION + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2025 NVIDIA CORPORATION + This file has not been changed + """ + ), + "file.txt", + dedent( + r""" + Copyright (c) 2021-2023 NVIDIA CORPORATION + Copyright (c) 2023 NVIDIA CORPORATION + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2025 NVIDIA CORPORATION + This file has been changed + """ + ), + [ + LintWarning( + (15, 24), + "copyright is out of date", + replacements=[ + Replacement( + (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION" + ), + ], + ), + LintWarning( + (58, 62), + "copyright is out of date", + replacements=[ + Replacement( + (44, 81), "Copyright (c) 2023-2024, NVIDIA CORPORATION" + ), + ], + ), + ], + ), + ( + "A", + None, + None, + "file.txt", + dedent( + r""" + Copyright (c) 2021-2023 NVIDIA CORPORATION + Copyright (c) 2023 NVIDIA CORPORATION + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2025 NVIDIA CORPORATION + This file has been changed + """ + ), + [ + LintWarning( + (15, 24), + "copyright is out of date", + replacements=[ + Replacement( + (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION" + ), + ], + ), + LintWarning( + (58, 62), + "copyright is out of date", + replacements=[ + Replacement( + (44, 81), "Copyright (c) 2023-2024, NVIDIA CORPORATION" + ), + ], + ), + ], + ), + ( + "M", + "file.txt", + dedent( + r""" + Copyright (c) 2021-2023 NVIDIA CORPORATION + Copyright (c) 2023 NVIDIA CORPORATION + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2025 NVIDIA CORPORATION + This file has not been changed + """ + ), + "file.txt", + dedent( + r""" + Copyright (c) 2021-2024 NVIDIA CORPORATION + Copyright (c) 2023 NVIDIA CORPORATION + Copyright (c) 2024 NVIDIA CORPORATION + Copyright (c) 2025 NVIDIA Corporation + This file has not been changed + """ + ), + [ + LintWarning( + (15, 24), + "copyright is not out of date and should not be updated", + replacements=[ + Replacement( + (1, 43), "Copyright (c) 2021-2023 NVIDIA CORPORATION" + ), + ], + ), + LintWarning( + (120, 157), + "copyright is not out of date and should not be updated", + replacements=[ + Replacement( + (120, 157), "Copyright (c) 2025 NVIDIA CORPORATION" + ), + ], + ), + ], + ), + ], +) @freeze_time("2024-01-18") -def test_apply_copyright_check(): - def run_apply_copyright_check(old_content, new_content): - linter = Linter("file.txt", new_content) - copyright.apply_copyright_check(linter, old_content) - return linter - - expected_linter = Linter("file.txt", "No copyright notice") - expected_linter.add_warning((0, 0), "no copyright notice found") - - linter = run_apply_copyright_check(None, "No copyright notice") - assert linter.warnings == expected_linter.warnings - - linter = run_apply_copyright_check("No copyright notice", "No copyright notice") - assert linter.warnings == [] - - OLD_CONTENT = dedent( - r""" - Copyright (c) 2021-2023 NVIDIA CORPORATION - Copyright (c) 2023 NVIDIA CORPORATION - Copyright (c) 2024 NVIDIA CORPORATION - Copyright (c) 2025 NVIDIA CORPORATION - This file has not been changed - """ - ) - linter = run_apply_copyright_check(OLD_CONTENT, OLD_CONTENT) - assert linter.warnings == [] - - NEW_CONTENT = dedent( - r""" - Copyright (c) 2021-2023 NVIDIA CORPORATION - Copyright (c) 2023 NVIDIA CORPORATION - Copyright (c) 2024 NVIDIA CORPORATION - Copyright (c) 2025 NVIDIA CORPORATION - This file has been changed - """ - ) - expected_linter = Linter("file.txt", NEW_CONTENT) - expected_linter.add_warning((15, 24), "copyright is out of date").add_replacement( - (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION" - ) - expected_linter.add_warning((58, 62), "copyright is out of date").add_replacement( - (44, 81), "Copyright (c) 2023-2024, NVIDIA CORPORATION" - ) - - linter = run_apply_copyright_check(OLD_CONTENT, NEW_CONTENT) - assert linter.warnings == expected_linter.warnings - - expected_linter = Linter("file.txt", NEW_CONTENT) - expected_linter.add_warning((15, 24), "copyright is out of date").add_replacement( - (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION" - ) - expected_linter.add_warning((58, 62), "copyright is out of date").add_replacement( - (44, 81), "Copyright (c) 2023-2024, NVIDIA CORPORATION" - ) - - linter = run_apply_copyright_check(None, NEW_CONTENT) - assert linter.warnings == expected_linter.warnings - - NEW_CONTENT = dedent( - r""" - Copyright (c) 2021-2024 NVIDIA CORPORATION - Copyright (c) 2023 NVIDIA CORPORATION - Copyright (c) 2024 NVIDIA CORPORATION - Copyright (c) 2025 NVIDIA Corporation - This file has not been changed - """ - ) - expected_linter = Linter("file.txt", NEW_CONTENT) - expected_linter.add_warning( - (15, 24), "copyright is not out of date and should not be updated" - ).add_replacement((1, 43), "Copyright (c) 2021-2023 NVIDIA CORPORATION") - expected_linter.add_warning( - (120, 157), "copyright is not out of date and should not be updated" - ).add_replacement((120, 157), "Copyright (c) 2025 NVIDIA CORPORATION") - - linter = run_apply_copyright_check(OLD_CONTENT, NEW_CONTENT) - assert linter.warnings == expected_linter.warnings +def test_apply_copyright_check( + change_type, old_filename, old_content, new_filename, new_content, warnings +): + linter = Linter(new_filename, new_content) + copyright.apply_copyright_check(linter, change_type, old_filename, old_content) + assert linter.warnings == warnings @pytest.fixture @@ -537,8 +641,8 @@ def mock_os_walk(top): with open(os.path.join(non_git_dir, "subdir1", "subdir2", "sub.txt"), "w") as f: f.write("Subdir file\n") assert copyright.get_changed_files(Mock()) == { - "top.txt": None, - "subdir1/subdir2/sub.txt": None, + "top.txt": ("A", None), + "subdir1/subdir2/sub.txt": ("A", None), } def fn(filename): @@ -582,16 +686,16 @@ def file_contents(verbed): Mock(return_value=None), ): assert copyright.get_changed_files(Mock()) == { - "untouched.txt": None, - "copied.txt": None, - "modified_and_copied.txt": None, - "copied_and_modified.txt": None, - "deleted.txt": None, - "renamed.txt": None, - "modified_and_renamed.txt": None, - "modified.txt": None, - "chmodded.txt": None, - "untracked.txt": None, + "untouched.txt": ("A", None), + "copied.txt": ("A", None), + "modified_and_copied.txt": ("A", None), + "copied_and_modified.txt": ("A", None), + "deleted.txt": ("A", None), + "renamed.txt": ("A", None), + "modified_and_renamed.txt": ("A", None), + "modified.txt": ("A", None), + "chmodded.txt": ("A", None), + "untracked.txt": ("A", None), } git_repo.index.commit("Initial commit") @@ -655,20 +759,20 @@ def file_contents(verbed): # Truly need to be checked changed = { - "added.txt": None, - "untracked.txt": None, - "modified_and_renamed_2.txt": "modified_and_renamed.txt", - "modified.txt": "modified.txt", - "copied_and_modified_2.txt": "copied_and_modified.txt", - "modified_and_copied.txt": "modified_and_copied.txt", + "added.txt": ("A", None), + "untracked.txt": ("A", None), + "modified_and_renamed_2.txt": ("R", "modified_and_renamed.txt"), + "modified.txt": ("M", "modified.txt"), + "copied_and_modified_2.txt": ("C", "copied_and_modified.txt"), + "modified_and_copied.txt": ("M", "modified_and_copied.txt"), } # Superfluous, but harmless because the content is identical superfluous = { - "chmodded.txt": "chmodded.txt", - "modified_and_copied_2.txt": "modified_and_copied.txt", - "copied_2.txt": "copied.txt", - "renamed_2.txt": "renamed.txt", + "chmodded.txt": ("M", "chmodded.txt"), + "modified_and_copied_2.txt": ("C", "modified_and_copied.txt"), + "copied_2.txt": ("C", "copied.txt"), + "renamed_2.txt": ("R", "renamed.txt"), } with patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)), patch( @@ -677,25 +781,25 @@ def file_contents(verbed): ): changed_files = copyright.get_changed_files(Mock()) assert { - path: old_blob.path if old_blob else None + path: (old_blob[0], old_blob[1].path if old_blob[1] else None) for path, old_blob in changed_files.items() } == changed | superfluous for new, old in changed.items(): - if old: + if old[1]: with open(fn(new), "rb") as f: new_contents = f.read() - old_contents = old_files[old].data_stream.read() + old_contents = old_files[old[1]].data_stream.read() assert new_contents != old_contents - assert changed_files[new].data_stream.read() == old_contents + assert changed_files[new][1].data_stream.read() == old_contents for new, old in superfluous.items(): - if old: + if old[1]: with open(fn(new), "rb") as f: new_contents = f.read() - old_contents = old_files[old].data_stream.read() + old_contents = old_files[old[1]].data_stream.read() assert new_contents == old_contents - assert changed_files[new].data_stream.read() == old_contents + assert changed_files[new][1].data_stream.read() == old_contents def test_get_changed_files_multiple_merge_bases(git_repo): @@ -768,12 +872,12 @@ def write_file(filename, contents): ): changed_files = copyright.get_changed_files(Mock()) assert { - path: old_blob.path if old_blob else None + path: (old_blob[0], old_blob[1].path if old_blob[1] else None) for path, old_blob in changed_files.items() } == { - "file1.txt": "file1.txt", - "file2.txt": "file2.txt", - "file3.txt": "file3.txt", + "file1.txt": ("M", "file1.txt"), + "file2.txt": ("M", "file2.txt"), + "file3.txt": ("M", "file3.txt"), } @@ -914,22 +1018,28 @@ def mock_apply_copyright_check(): linter = Linter("file5.txt", file_contents(2)) with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, mock_args) - apply_copyright_check.assert_called_once_with(linter, file_contents(2)) + apply_copyright_check.assert_called_once_with( + linter, "R", "file2.txt", file_contents(2) + ) linter = Linter("file3.txt", file_contents_modified(3)) with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, mock_args) - apply_copyright_check.assert_called_once_with(linter, file_contents(3)) + apply_copyright_check.assert_called_once_with( + linter, "M", "file3.txt", file_contents(3) + ) linter = Linter("file4.txt", file_contents_modified(4)) with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, mock_args) - apply_copyright_check.assert_called_once_with(linter, file_contents(4)) + apply_copyright_check.assert_called_once_with( + linter, "M", "file4.txt", file_contents(4) + ) linter = Linter("file6.txt", file_contents(6)) with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, mock_args) - apply_copyright_check.assert_called_once_with(linter, None) + apply_copyright_check.assert_called_once_with(linter, "A", None, None) ############################# # branch-2 is target branch @@ -943,12 +1053,16 @@ def mock_apply_copyright_check(): linter = Linter("file1.txt", file_contents_modified(1)) with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, mock_args) - apply_copyright_check.assert_called_once_with(linter, file_contents(1)) + apply_copyright_check.assert_called_once_with( + linter, "M", "file1.txt", file_contents(1) + ) linter = Linter("./file1.txt", file_contents_modified(1)) with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, mock_args) - apply_copyright_check.assert_called_once_with(linter, file_contents(1)) + apply_copyright_check.assert_called_once_with( + linter, "M", "file1.txt", file_contents(1) + ) linter = Linter("../file1.txt", file_contents_modified(1)) with mock_apply_copyright_check() as apply_copyright_check: @@ -963,19 +1077,25 @@ def mock_apply_copyright_check(): linter = Linter("file5.txt", file_contents(2)) with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, mock_args) - apply_copyright_check.assert_called_once_with(linter, file_contents(2)) + apply_copyright_check.assert_called_once_with( + linter, "R", "file2.txt", file_contents(2) + ) linter = Linter("file3.txt", file_contents_modified(3)) with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, mock_args) - apply_copyright_check.assert_called_once_with(linter, file_contents(3)) + apply_copyright_check.assert_called_once_with( + linter, "M", "file3.txt", file_contents(3) + ) linter = Linter("file4.txt", file_contents_modified(4)) with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, mock_args) - apply_copyright_check.assert_called_once_with(linter, file_contents(4)) + apply_copyright_check.assert_called_once_with( + linter, "M", "file4.txt", file_contents(4) + ) linter = Linter("file6.txt", file_contents(6)) with mock_apply_copyright_check() as apply_copyright_check: copyright_checker(linter, mock_args) - apply_copyright_check.assert_called_once_with(linter, None) + apply_copyright_check.assert_called_once_with(linter, "A", None, None)