Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(autofix/codegen): File search tools #1134

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions src/seer/automation/autofix/tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import fnmatch
import logging
import os
import textwrap

from langfuse.decorators import observe
Expand Down Expand Up @@ -176,6 +178,46 @@ def keyword_search(

return result_str

@observe(name="File Search")
@ai_track(description="File Search")
def file_search(
self,
filename: str,
repo_name: str | None = None,
):
"""
Given a filename with extension returns the list of locations where a file with the name is found.
"""
repo_client = self.context.get_repo_client(repo_name=repo_name)
all_paths = repo_client.get_index_file_set()
found = [path for path in all_paths if os.path.basename(path) == filename]
if len(found) == 0:
return f"no file with name {filename} found in repository"

found = sorted(found)

return ",".join(found)

@observe(name="File Search Wildcard")
@ai_track(description="File Search Wildcard")
def file_search_wildcard(
self,
pattern: str,
repo_name: str | None = None,
):
"""
Given a filename pattern with wildcards, returns the list of file paths that match the pattern.
"""
repo_client = self.context.get_repo_client(repo_name=repo_name)
all_paths = repo_client.get_index_file_set()
found = [path for path in all_paths if fnmatch.fnmatch(path, pattern)]
if len(found) == 0:
return f"No files matching pattern '{pattern}' found in repository"

found = sorted(found)

return "\n".join(found)

def get_tools(self):
tools = [
FunctionTool(
Expand Down Expand Up @@ -245,6 +287,40 @@ def get_tools(self):
},
],
),
FunctionTool(
name="file_search",
fn=self.file_search,
description="Searches for a file in the codebase.",
parameters=[
{
"name": "filename",
"type": "string",
"description": "The file to search for.",
},
{
"name": "repo_name",
"type": "string",
"description": "Optional name of the repository to search in if you know it.",
},
],
),
FunctionTool(
name="file_search_wildcard",
fn=self.file_search_wildcard,
description="Searches for files in a folder using a wildcard pattern.",
parameters=[
{
"name": "pattern",
"type": "string",
"description": "The wildcard pattern to match files.",
},
{
"name": "repo_name",
"type": "string",
"description": "Optional name of the repository to search in if you know it.",
},
],
),
]

return tools
79 changes: 79 additions & 0 deletions tests/automation/autofix/test_autofix_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from unittest.mock import MagicMock

import pytest

from seer.automation.autofix.tools import BaseTools


@pytest.fixture
def autofix_tools():
    """Provide a ``BaseTools`` instance constructed with a ``MagicMock`` context."""
    return BaseTools(MagicMock())


class TestFileSearch:
    """Tests for ``BaseTools.file_search`` using a mocked repo client."""

    def test_file_search_found(self, autofix_tools: BaseTools):
        """All paths with the requested file name come back sorted and comma-joined."""
        indexed_paths = {
            "src/file1.py",
            "tests/file2.py",
            "src/subfolder/file2.py",
        }
        repo_client = MagicMock()
        repo_client.get_index_file_set.return_value = indexed_paths
        autofix_tools.context.get_repo_client.return_value = repo_client

        outcome = autofix_tools.file_search("file2.py")

        assert outcome == "src/subfolder/file2.py,tests/file2.py"

    def test_file_search_not_found(self, autofix_tools: BaseTools):
        """A name absent from the index produces the not-found message."""
        indexed_paths = {
            "src/file1.py",
            "tests/file2.py",
            "src/subfolder/file3.py",
        }
        repo_client = MagicMock()
        repo_client.get_index_file_set.return_value = indexed_paths
        autofix_tools.context.get_repo_client.return_value = repo_client

        outcome = autofix_tools.file_search("nonexistent.py")

        assert outcome == "no file with name nonexistent.py found in repository"

    def test_file_search_with_repo_name(self, autofix_tools: BaseTools):
        """The ``repo_name`` argument is forwarded to ``get_repo_client``."""
        repo_client = MagicMock()
        repo_client.get_index_file_set.return_value = {"src/file1.py"}
        get_repo_client = autofix_tools.context.get_repo_client
        get_repo_client.return_value = repo_client

        autofix_tools.file_search("file1.py", repo_name="test_repo")

        get_repo_client.assert_called_once_with(repo_name="test_repo")


class TestFileSearchWildcard:
    """Tests for ``BaseTools.file_search_wildcard`` using a mocked repo client."""

    def test_file_search_wildcard_found(self, autofix_tools: BaseTools):
        """Paths matching the pattern come back sorted and newline-joined."""
        indexed_paths = {
            "src/file1.py",
            "tests/test_file1.py",
            "src/subfolder/file2.py",
        }
        repo_client = MagicMock()
        repo_client.get_index_file_set.return_value = indexed_paths
        autofix_tools.context.get_repo_client.return_value = repo_client

        outcome = autofix_tools.file_search_wildcard("*.py")

        assert outcome == "src/file1.py\nsrc/subfolder/file2.py\ntests/test_file1.py"

    def test_file_search_wildcard_not_found(self, autofix_tools: BaseTools):
        """A pattern that matches nothing produces the not-found message."""
        indexed_paths = {
            "src/file1.py",
            "tests/test_file1.py",
            "src/subfolder/file2.py",
        }
        repo_client = MagicMock()
        repo_client.get_index_file_set.return_value = indexed_paths
        autofix_tools.context.get_repo_client.return_value = repo_client

        outcome = autofix_tools.file_search_wildcard("*.js")

        assert outcome == "No files matching pattern '*.js' found in repository"

    def test_file_search_wildcard_with_repo_name(self, autofix_tools: BaseTools):
        """The ``repo_name`` argument is forwarded to ``get_repo_client``."""
        repo_client = MagicMock()
        repo_client.get_index_file_set.return_value = {"src/file1.py"}
        get_repo_client = autofix_tools.context.get_repo_client
        get_repo_client.return_value = repo_client

        autofix_tools.file_search_wildcard("*.py", repo_name="test_repo")

        get_repo_client.assert_called_once_with(repo_name="test_repo")
Loading