diff --git a/src/seer/automation/autofix/tools.py b/src/seer/automation/autofix/tools.py index 2fa50780..f7e523fb 100644 --- a/src/seer/automation/autofix/tools.py +++ b/src/seer/automation/autofix/tools.py @@ -1,4 +1,6 @@ +import fnmatch import logging +import os import textwrap from langfuse.decorators import observe @@ -176,6 +178,46 @@ def keyword_search( return result_str + @observe(name="File Search") + @ai_track(description="File Search") + def file_search( + self, + filename: str, + repo_name: str | None = None, + ): + """ + Given a filename with extension returns the list of locations where a file with the name is found. + """ + repo_client = self.context.get_repo_client(repo_name=repo_name) + all_paths = repo_client.get_index_file_set() + found = [path for path in all_paths if os.path.basename(path) == filename] + if len(found) == 0: + return f"no file with name {filename} found in repository" + + found = sorted(found) + + return ",".join(found) + + @observe(name="File Search Wildcard") + @ai_track(description="File Search Wildcard") + def file_search_wildcard( + self, + pattern: str, + repo_name: str | None = None, + ): + """ + Given a filename pattern with wildcards, returns the list of file paths that match the pattern. + """ + repo_client = self.context.get_repo_client(repo_name=repo_name) + all_paths = repo_client.get_index_file_set() + found = [path for path in all_paths if fnmatch.fnmatch(path, pattern)] + if len(found) == 0: + return f"No files matching pattern '{pattern}' found in repository" + + found = sorted(found) + + return "\n".join(found) + def get_tools(self): tools = [ FunctionTool( @@ -245,6 +287,40 @@ def get_tools(self): }, ], ), + FunctionTool( + name="file_search", + fn=self.file_search, + description="Searches for a file in the codebase.", + parameters=[ + { + "name": "filename", + "type": "string", + "description": "The file to search for.", + }, + { + "name": "repo_name", + "type": "string", + "description": "Optional name of the repository to search in if you know it.", + }, + ], + ), + FunctionTool( + name="file_search_wildcard", + fn=self.file_search_wildcard, + description="Searches for files in a folder using a wildcard pattern.", + parameters=[ + { + "name": "pattern", + "type": "string", + "description": "The wildcard pattern to match files.", + }, + { + "name": "repo_name", + "type": "string", + "description": "Optional name of the repository to search in if you know it.", + }, + ], + ), ] return tools diff --git a/tests/automation/autofix/test_autofix_tools.py b/tests/automation/autofix/test_autofix_tools.py new file mode 100644 index 00000000..a3b9aae9 --- /dev/null +++ b/tests/automation/autofix/test_autofix_tools.py @@ -0,0 +1,79 @@ +from unittest.mock import MagicMock + +import pytest + +from seer.automation.autofix.tools import BaseTools + + +@pytest.fixture +def autofix_tools(): + context = MagicMock() + return BaseTools(context) + + +class TestFileSearch: + def test_file_search_found(self, autofix_tools: BaseTools): + mock_repo_client = MagicMock() + mock_repo_client.get_index_file_set.return_value = { + "src/file1.py", + "tests/file2.py", + "src/subfolder/file2.py", + } + autofix_tools.context.get_repo_client.return_value = mock_repo_client + + result = autofix_tools.file_search("file2.py") + assert result == "src/subfolder/file2.py,tests/file2.py" + + def test_file_search_not_found(self, autofix_tools: BaseTools): + mock_repo_client = MagicMock() + mock_repo_client.get_index_file_set.return_value = { + "src/file1.py", + "tests/file2.py", + "src/subfolder/file3.py", + } + autofix_tools.context.get_repo_client.return_value = mock_repo_client + + result = autofix_tools.file_search("nonexistent.py") + assert result == "no file with name nonexistent.py found in repository" + + def test_file_search_with_repo_name(self, autofix_tools: BaseTools): + mock_repo_client = MagicMock() + mock_repo_client.get_index_file_set.return_value = {"src/file1.py"} + autofix_tools.context.get_repo_client.return_value = mock_repo_client + + autofix_tools.file_search("file1.py", repo_name="test_repo") + autofix_tools.context.get_repo_client.assert_called_once_with(repo_name="test_repo") + + +class TestFileSearchWildcard: + def test_file_search_wildcard_found(self, autofix_tools: BaseTools): + mock_repo_client = MagicMock() + mock_repo_client.get_index_file_set.return_value = { + "src/file1.py", + "tests/test_file1.py", + "src/subfolder/file2.py", + } + autofix_tools.context.get_repo_client.return_value = mock_repo_client + + result = autofix_tools.file_search_wildcard("*.py") + assert result == "src/file1.py\nsrc/subfolder/file2.py\ntests/test_file1.py" + + def test_file_search_wildcard_not_found(self, autofix_tools: BaseTools): + mock_repo_client = MagicMock() + mock_repo_client.get_index_file_set.return_value = { + "src/file1.py", + "tests/test_file1.py", + "src/subfolder/file2.py", + } + autofix_tools.context.get_repo_client.return_value = mock_repo_client + + result = autofix_tools.file_search_wildcard("*.js") + assert result == "No files matching pattern '*.js' found in repository" + + def test_file_search_wildcard_with_repo_name(self, autofix_tools: BaseTools): + mock_repo_client = MagicMock() + mock_repo_client.get_index_file_set.return_value = {"src/file1.py"} + autofix_tools.context.get_repo_client.return_value = mock_repo_client + + autofix_tools.file_search_wildcard("*.py", repo_name="test_repo") + autofix_tools.context.get_repo_client.assert_called_once_with(repo_name="test_repo")