diff --git a/code2prompt/utils/is_filtered.py b/code2prompt/utils/is_filtered.py index f4da447..a4d5ce1 100644 --- a/code2prompt/utils/is_filtered.py +++ b/code2prompt/utils/is_filtered.py @@ -59,7 +59,7 @@ def prepare_patterns(pattern): # Check exclude patterns first (they take precedence) if match_patterns(file_path_str, exclude_patterns): - return False + return False # Exclude dotfiles and other specified patterns # If include patterns are specified, the file must match at least one if include_patterns: diff --git a/tests/test_create_markdown_with_filter.py b/tests/test_create_markdown_with_filter.py index 99ff3e3..b12e98a 100644 --- a/tests/test_create_markdown_with_filter.py +++ b/tests/test_create_markdown_with_filter.py @@ -1,13 +1,10 @@ -from code2prompt.main import create_markdown_file - +from code2prompt.main import create_markdown_file_command from click.testing import CliRunner - import tempfile from pathlib import Path - def test_create_markdown_with_filter(): runner = CliRunner() with tempfile.TemporaryDirectory() as temp_dir: @@ -19,7 +16,7 @@ def test_create_markdown_with_filter(): filter_option = "*.py" output_file = temp_dir_path / "output_with_filter.md" - result = runner.invoke(create_markdown_file, ['-p', temp_dir, '-o', str(output_file), '-f', filter_option]) + result = runner.invoke(create_markdown_file_command, ['-p', temp_dir, '-o', str(output_file), '-f', filter_option]) assert result.exit_code == 0 assert output_file.exists() diff --git a/tests/test_is_filtered.py b/tests/test_is_filtered.py index 2380eb0..6272b74 100644 --- a/tests/test_is_filtered.py +++ b/tests/test_is_filtered.py @@ -2,6 +2,8 @@ from pathlib import Path from code2prompt.utils.is_filtered import is_filtered +# Removed incorrect import + @pytest.mark.parametrize( "file_path, include_pattern, exclude_pattern, case_sensitive, expected", @@ -24,17 +26,22 @@ (Path("file_without_extension"), "", "*.*", False, True), (Path("deeply/nested/directory/file.txt"), "**/*.txt", "", False, True), (Path("file.txt.bak"), "", "*.bak", False, False), + ( + Path("file.py"), + "syntax_map:*.py", + "", + False, + True, + ), # New test case for syntax map + ( + Path("file.txt"), + "syntax_map:*.py", + "", + False, + False, + ), # New test case for syntax map ], ) -def test_is_filtered( - file_path, include_pattern, exclude_pattern, case_sensitive, expected -): - assert ( - is_filtered(file_path, include_pattern, exclude_pattern, case_sensitive) - == expected - ) - - def test_is_filtered_with_directories(): assert is_filtered( Path("src/test"), "**/test", "", False @@ -53,3 +60,22 @@ def test_is_filtered_case_sensitivity(): def test_is_filtered_exclude_precedence(): assert not is_filtered(Path("important_test.py"), "*.py", "*test*", False) + + +# Define test cases +test_cases = [ + (Path(".gitignore"), "", "**/.gitignore", False), # Should be excluded + (Path(".codetopromptrc"), "", "**/.codetopromptrc", False), # Should be excluded + (Path("README.md"), "", "", True), # Should be included + (Path("notes.txt"), "", "**/*.txt", False), # Should be excluded + (Path("file.py"), "*.py", "", True), # Should be included +] + +# Run tests +for file_path, include, exclude, expected in test_cases: + result = is_filtered(file_path, include, exclude) + assert ( + result == expected + ), f"Test failed for {file_path}: expected {expected}, got {result}" + +print("All tests passed!") diff --git a/tests/test_language_inference.py b/tests/test_language_inference.py index 52e1e59..5d676a4 100644 --- a/tests/test_language_inference.py +++ b/tests/test_language_inference.py @@ -1,27 +1,29 @@ +import pytest from code2prompt.utils.language_inference import infer_language def test_infer_language(): """ Test the infer_language function.""" - assert infer_language("main.c") == "c" - assert infer_language("main.cpp") == "cpp" - assert infer_language("Main.java") == "java" - assert infer_language("script.js") == "javascript" - assert infer_language("Program.cs") == "csharp" - assert infer_language("index.php") == "php" - assert infer_language("main.go") == "go" - assert infer_language("lib.rs") == "rust" - assert infer_language("app.kt") == "kotlin" - assert infer_language("main.swift") == "swift" - assert infer_language("Main.scala") == "scala" - assert infer_language("main.dart") == "dart" - assert infer_language("script.py") == "python" - assert infer_language("script.rb") == "ruby" - assert infer_language("script.pl") == "perl" - assert infer_language("script.sh") == "bash" - assert infer_language("script.ps1") == "powershell" - assert infer_language("index.html") == "html" - assert infer_language("data.xml") == "xml" - assert infer_language("query.sql") == "sql" - assert infer_language("script.m") == "matlab" - assert infer_language("script.r") == "r" - assert infer_language("file.txt") == "plaintext" + syntax_map = {} # Define the syntax map as needed + assert infer_language("main.c", syntax_map) == "c" # Added syntax_map argument + assert infer_language("main.cpp", syntax_map) == "cpp" + assert infer_language("Main.java", syntax_map) == "java" + assert infer_language("script.js", syntax_map) == "javascript" + assert infer_language("Program.cs", syntax_map) == "csharp" + assert infer_language("index.php", syntax_map) == "php" + assert infer_language("main.go", syntax_map) == "go" + assert infer_language("lib.rs", syntax_map) == "rust" + assert infer_language("app.kt", syntax_map) == "kotlin" + assert infer_language("main.swift", syntax_map) == "swift" + assert infer_language("Main.scala", syntax_map) == "scala" + assert infer_language("main.dart", syntax_map) == "dart" + assert infer_language("script.py", syntax_map) == "python" + assert infer_language("script.rb", syntax_map) == "ruby" + assert infer_language("script.pl", syntax_map) == "perl" + assert infer_language("script.sh", syntax_map) == "bash" + assert infer_language("script.ps1", syntax_map) == "powershell" + assert infer_language("index.html", syntax_map) == "html" + assert infer_language("data.xml", syntax_map) == "xml" + assert infer_language("query.sql", syntax_map) == "sql" + assert infer_language("script.m", syntax_map) == "matlab" + assert infer_language("script.r", syntax_map) == "r" + assert infer_language("file.txt", syntax_map) == "plaintext"