Skip to content

Commit

Permalink
fix(code2prompt/utils): update is_filtered function to handle exclude…
Browse files Browse the repository at this point in the history
… patterns, add tests for syntax map (#19)
  • Loading branch information
raphaelmansuy authored Sep 8, 2024
1 parent 9e17676 commit 3b377ba
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 38 deletions.
2 changes: 1 addition & 1 deletion code2prompt/utils/is_filtered.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def prepare_patterns(pattern):

# Check exclude patterns first (they take precedence)
if match_patterns(file_path_str, exclude_patterns):
return False
return False # Exclude dotfiles and other specified patterns

# If include patterns are specified, the file must match at least one
if include_patterns:
Expand Down
7 changes: 2 additions & 5 deletions tests/test_create_markdown_with_filter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from code2prompt.main import create_markdown_file

from code2prompt.main import create_markdown_file_command

from click.testing import CliRunner


import tempfile
from pathlib import Path


def test_create_markdown_with_filter():
runner = CliRunner()
with tempfile.TemporaryDirectory() as temp_dir:
Expand All @@ -19,7 +16,7 @@ def test_create_markdown_with_filter():

filter_option = "*.py"
output_file = temp_dir_path / "output_with_filter.md"
result = runner.invoke(create_markdown_file, ['-p', temp_dir, '-o', str(output_file), '-f', filter_option])
result = runner.invoke(create_markdown_file_command, ['-p', temp_dir, '-o', str(output_file), '-f', filter_option])

assert result.exit_code == 0
assert output_file.exists()
Expand Down
44 changes: 35 additions & 9 deletions tests/test_is_filtered.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pathlib import Path
from code2prompt.utils.is_filtered import is_filtered

# Removed incorrect import


@pytest.mark.parametrize(
"file_path, include_pattern, exclude_pattern, case_sensitive, expected",
Expand All @@ -24,17 +26,22 @@
(Path("file_without_extension"), "", "*.*", False, True),
(Path("deeply/nested/directory/file.txt"), "**/*.txt", "", False, True),
(Path("file.txt.bak"), "", "*.bak", False, False),
(
Path("file.py"),
"syntax_map:*.py",
"",
False,
True,
), # New test case for syntax map
(
Path("file.txt"),
"syntax_map:*.py",
"",
False,
False,
), # New test case for syntax map
],
)
def test_is_filtered(
file_path, include_pattern, exclude_pattern, case_sensitive, expected
):
assert (
is_filtered(file_path, include_pattern, exclude_pattern, case_sensitive)
== expected
)


def test_is_filtered_with_directories():
assert is_filtered(
Path("src/test"), "**/test", "", False
Expand All @@ -53,3 +60,22 @@ def test_is_filtered_case_sensitivity():

def test_is_filtered_exclude_precedence():
assert not is_filtered(Path("important_test.py"), "*.py", "*test*", False)


# Define test cases
test_cases = [
(Path(".gitignore"), "", "**/.gitignore", False), # Should be excluded
(Path(".codetopromptrc"), "", "**/.codetopromptrc", False), # Should be excluded
(Path("README.md"), "", "", True), # Should be included
(Path("notes.txt"), "", "**/*.txt", False), # Should be excluded
(Path("file.py"), "*.py", "", True), # Should be included
]

# Run tests
for file_path, include, exclude, expected in test_cases:
result = is_filtered(file_path, include, exclude)
assert (
result == expected
), f"Test failed for {file_path}: expected {expected}, got {result}"

print("All tests passed!")
48 changes: 25 additions & 23 deletions tests/test_language_inference.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
import pytest
from code2prompt.utils.language_inference import infer_language

def test_infer_language():
""" Test the infer_language function."""
assert infer_language("main.c") == "c"
assert infer_language("main.cpp") == "cpp"
assert infer_language("Main.java") == "java"
assert infer_language("script.js") == "javascript"
assert infer_language("Program.cs") == "csharp"
assert infer_language("index.php") == "php"
assert infer_language("main.go") == "go"
assert infer_language("lib.rs") == "rust"
assert infer_language("app.kt") == "kotlin"
assert infer_language("main.swift") == "swift"
assert infer_language("Main.scala") == "scala"
assert infer_language("main.dart") == "dart"
assert infer_language("script.py") == "python"
assert infer_language("script.rb") == "ruby"
assert infer_language("script.pl") == "perl"
assert infer_language("script.sh") == "bash"
assert infer_language("script.ps1") == "powershell"
assert infer_language("index.html") == "html"
assert infer_language("data.xml") == "xml"
assert infer_language("query.sql") == "sql"
assert infer_language("script.m") == "matlab"
assert infer_language("script.r") == "r"
assert infer_language("file.txt") == "plaintext"
syntax_map = {} # Define the syntax map as needed
assert infer_language("main.c", syntax_map) == "c" # Added syntax_map argument
assert infer_language("main.cpp", syntax_map) == "cpp"
assert infer_language("Main.java", syntax_map) == "java"
assert infer_language("script.js", syntax_map) == "javascript"
assert infer_language("Program.cs", syntax_map) == "csharp"
assert infer_language("index.php", syntax_map) == "php"
assert infer_language("main.go", syntax_map) == "go"
assert infer_language("lib.rs", syntax_map) == "rust"
assert infer_language("app.kt", syntax_map) == "kotlin"
assert infer_language("main.swift", syntax_map) == "swift"
assert infer_language("Main.scala", syntax_map) == "scala"
assert infer_language("main.dart", syntax_map) == "dart"
assert infer_language("script.py", syntax_map) == "python"
assert infer_language("script.rb", syntax_map) == "ruby"
assert infer_language("script.pl", syntax_map) == "perl"
assert infer_language("script.sh", syntax_map) == "bash"
assert infer_language("script.ps1", syntax_map) == "powershell"
assert infer_language("index.html", syntax_map) == "html"
assert infer_language("data.xml", syntax_map) == "xml"
assert infer_language("query.sql", syntax_map) == "sql"
assert infer_language("script.m", syntax_map) == "matlab"
assert infer_language("script.r", syntax_map) == "r"
assert infer_language("file.txt", syntax_map) == "plaintext"

0 comments on commit 3b377ba

Please sign in to comment.