From 54230c52f82fa2e69bd87498db4f024480d9b15b Mon Sep 17 00:00:00 2001 From: "B.T. Franklin" Date: Mon, 11 Nov 2024 21:48:27 -0700 Subject: [PATCH] Updated unit tests to demonstrate argument exclusion feature --- README.md | 2 + pdm.lock | 2 +- pyproject.toml | 2 +- tests/conftest.py | 16 ++ tests/{test_cache.py => test_cache_basic.py} | 123 +++++--------- tests/test_exclude_args.py | 162 +++++++++++++++++++ tests/test_pipeline.py | 32 ++-- tests/test_truncation.py | 65 ++++++++ 8 files changed, 301 insertions(+), 103 deletions(-) create mode 100644 tests/conftest.py rename tests/{test_cache.py => test_cache_basic.py} (56%) create mode 100644 tests/test_exclude_args.py create mode 100644 tests/test_truncation.py diff --git a/README.md b/README.md index 887beb6..4068df8 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,8 @@ def my_function(unpickleable_arg, other_arg): - **`exclude_args`**: A list of argument names (as strings) to exclude from the cache key. This is useful when certain arguments cannot be pickled or should not influence caching. +**Warning**: Excluding arguments that affect the function's output can lead to incorrect caching behavior. The cache will return the result based on the included arguments, ignoring changes in the excluded arguments. Only exclude arguments that do not influence the function's output, such as unpickleable objects or instances that do not affect computation. + ### Building a Pipeline Here's an example of how to build a pipeline using cached functions: diff --git a/pdm.lock b/pdm.lock index 0d17798..08f4535 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:4d45cc6d3ec0051285813d2f2cc845562ef751a732976e62243207290298b823" +content_hash = "sha256:806f2647566ed7037c1d410a138c6f1877e716e869665e7dba594341fa3e35ff" [[metadata.targets]] requires_python = ">=3.10" diff --git a/pyproject.toml b/pyproject.toml index 31bbaa4..d59a00b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ excludes = ["tests/**"] [tool.pdm.dev-dependencies] dev = [ - "pytest>=8.3.1", + "pytest>=8.3.3", "flake8>=7.1.0", ] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2b5b756 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,16 @@ +import os +import shutil +import pytest +from pickled_pipeline import Cache + +TEST_CACHE_DIR = "test_pipeline_cache" + + +@pytest.fixture(scope="function") +def cache(): + # Set up: Create a Cache instance with a test cache directory + cache = Cache(cache_dir=TEST_CACHE_DIR) + yield cache + # Tear down: Remove the test cache directory after each test + if os.path.exists(TEST_CACHE_DIR): + shutil.rmtree(TEST_CACHE_DIR) diff --git a/tests/test_cache.py b/tests/test_cache_basic.py similarity index 56% rename from tests/test_cache.py rename to tests/test_cache_basic.py index 0f348cd..98b30f8 100644 --- a/tests/test_cache.py +++ b/tests/test_cache_basic.py @@ -1,20 +1,10 @@ +""" +Tests for the basic functionality of the pickled_pipeline.Cache class. +These tests ensure that caching, cache retrieval, and handling of different arguments work as expected. +""" + import os -import shutil import pytest -from pickled_pipeline import Cache - -# Define a temporary directory for caching during tests -TEST_CACHE_DIR = "test_pipeline_cache" - - -@pytest.fixture(scope="function") -def cache(): - # Set up: Create a Cache instance with a test cache directory - cache = Cache(cache_dir=TEST_CACHE_DIR) - yield cache - # Tear down: Remove the test cache directory after each test - if os.path.exists(TEST_CACHE_DIR): - shutil.rmtree(TEST_CACHE_DIR) def test_cache_checkpoint(cache): @@ -26,16 +16,20 @@ def test_function(x): # Call the function for the first time result1 = test_function(3) assert result1 == 9 + # Check that the cache file was created (excluding the manifest) - cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"] + cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"] assert len(cache_files) == 1 # Call the function again with the same argument result2 = test_function(3) assert result2 == 9 + # Ensure the cache file count hasn't increased - cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"] - assert len(cache_files) == 1 + cache_files_after = [ + f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json" + ] + assert len(cache_files_after) == 1 def test_custom_checkpoint_name(cache): @@ -51,7 +45,7 @@ def test_function(x): assert cache.checkpoint_order == ["custom_checkpoint_name"] # Check that the cache file was created with the custom name - cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"] + cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"] assert len(cache_files) == 1 assert cache_files[0].startswith("custom_checkpoint_name__") @@ -61,7 +55,7 @@ def test_function(x): # Ensure no new cache files were created cache_files_after = [ - f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json" + f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json" ] assert len(cache_files_after) == 1 @@ -71,7 +65,7 @@ def test_function(x): # Verify that a new cache file was created for the new input cache_files_final = [ - f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json" + f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json" ] assert len(cache_files_final) == 2 @@ -91,66 +85,10 @@ def test_function(x): assert result3 == 2 # Check that two cache files were created (excluding the manifest) - cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"] + cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"] assert len(cache_files) == 2 -def test_truncate_cache(cache): - # Define functions with arbitrary names - @cache.checkpoint() - def examine_input(): - return "input" - - @cache.checkpoint() - def open_document(): - return "document" - - @cache.checkpoint() - def process_details(): - return "details" - - @cache.checkpoint() - def analyze_result(): - return "result" - - # Run the pipeline - _ = examine_input() - _ = open_document() - _ = process_details() - _ = analyze_result() - - # Check that manifest has the correct order - expected_order = [ - "examine_input", - "open_document", - "process_details", - "analyze_result", - ] - assert cache.checkpoint_order == expected_order - - # Truncate from "open_document" - cache.truncate_cache("open_document") - - # Verify that cache files for "open_document" and subsequent checkpoints are deleted - remaining_checkpoints = cache.list_checkpoints() - assert remaining_checkpoints == ["examine_input"] - - # Verify that cache files are as expected (excluding the manifest) - cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"] - # There should be cache files only for 'examine_input' - assert len(cache_files) == 1 - assert cache_files[0].startswith("examine_input__") - - # Re-run the truncated steps - _ = open_document() - _ = process_details() - _ = analyze_result() - - # Verify that the cache is rebuilt - remaining_checkpoints = cache.list_checkpoints() - assert remaining_checkpoints == expected_order - - def test_cache_with_complex_arguments(cache): @cache.checkpoint() def complex_function(a, b): @@ -176,6 +114,31 @@ def non_pickleable_function(): non_pickleable_function() +def test_cache_with_non_serializable_included_arg(cache): + import threading + + unpickleable_arg = threading.Lock() + + @cache.checkpoint() + def test_function(unpickleable_arg): + return "result" + + # This should raise an exception because the argument cannot be pickled + with pytest.raises(Exception): + test_function(unpickleable_arg) + + +def test_cache_error_with_unpickleable_return(cache): + @cache.checkpoint() + def test_function(): + import threading + + return threading.Lock() # Unpickleable return value + + with pytest.raises(Exception): + test_function() + + def test_clear_cache(cache): @cache.checkpoint() def step1(): @@ -190,7 +153,7 @@ def step2(): _ = step2() # Ensure cache files are created (excluding manifest) - cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"] + cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"] assert len(cache_files) == 2 # Clear the cache @@ -198,7 +161,7 @@ def step2(): # Verify that cache files are deleted (excluding manifest) cache_files_after_clear = [ - f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json" + f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json" ] assert len(cache_files_after_clear) == 0 diff --git a/tests/test_exclude_args.py b/tests/test_exclude_args.py new file mode 100644 index 0000000..c0ac123 --- /dev/null +++ b/tests/test_exclude_args.py @@ -0,0 +1,162 @@ +""" +Tests for the 'exclude_args' feature of the pickled_pipeline.Cache class. +These tests ensure that arguments can be excluded from the cache key, +allowing for unpickleable objects or sensitive data to be passed without +affecting caching behavior. +""" + +import os +import threading +import pytest + + +def test_cache_with_excluded_unpickleable_argument(cache): + # Define an unpickleable object + unpickleable_arg = threading.Lock() + + @cache.checkpoint(exclude_args=["unpickleable_arg"]) + def test_function(x, unpickleable_arg): + return x * 2 + + # Call the function + result = test_function(5, unpickleable_arg) + assert result == 10 + + # Verify that the cache file was created + cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"] + assert len(cache_files) == 1 + + # Call the function again with the same 'x' but a different 'unpickleable_arg' + new_unpickleable_arg = threading.Lock() + result_cached = test_function(5, new_unpickleable_arg) + assert result_cached == 10 + + # Ensure that the cached result was used (no new cache file created) + cache_files_after = [ + f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json" + ] + assert len(cache_files_after) == 1 + + +def test_cache_excluded_argument_affects_result(cache): + @cache.checkpoint(exclude_args=["excluded_arg"]) + def test_function(x, excluded_arg): + return x * excluded_arg + + # Call the function with different 'excluded_arg' values + result1 = test_function(5, 2) + result2 = test_function(5, 3) + + # Since 'excluded_arg' is excluded from the cache key, both calls should retrieve the same cached result + # This demonstrates that excluding arguments affecting the output can lead to incorrect caching + assert result1 == result2 == 10 # Both results are from the first computation + + +def test_cache_with_multiple_excluded_arguments(cache): + unpickleable_arg1 = threading.Lock() + unpickleable_arg2 = threading.Lock() + + @cache.checkpoint(exclude_args=["unpickleable_arg1", "unpickleable_arg2"]) + def test_function(x, unpickleable_arg1, unpickleable_arg2): + return x * 2 + + # Call the function + result = test_function(5, unpickleable_arg1, unpickleable_arg2) + assert result == 10 + + # Verify that the cache file was created + cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"] + assert len(cache_files) == 1 + + +def test_cache_included_arguments_affect_cache(cache): + @cache.checkpoint(exclude_args=["excluded_arg"]) + def test_function(x, excluded_arg): + return x + excluded_arg + + # Call with x=5 + result1 = test_function(5, 10) + assert result1 == 15 + + # Call with x=6 + result2 = test_function(6, 10) + assert result2 == 16 + + # Verify that two cache files were created since 'x' is included in the cache key + cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"] + assert len(cache_files) == 2 + + +def test_cache_excluding_nonexistent_argument(cache): + @cache.checkpoint(exclude_args=["nonexistent_arg"]) + def test_function(x): + return x + 1 + + # Call the function + result = test_function(5) + assert result == 6 + + # Verify that cache works even when excluding a non-existent argument + cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"] + assert len(cache_files) == 1 + + +def test_cache_with_excluded_kwargs(cache): + @cache.checkpoint(exclude_args=["excluded_kwarg"]) + def test_function(x, **kwargs): + return x + kwargs.get("excluded_kwarg", 0) + + # Call the function with different 'excluded_kwarg' values + result1 = test_function(5, excluded_kwarg=10) + result2 = test_function(5, excluded_kwarg=20) + + # Since 'excluded_kwarg' is excluded, both results should be cached as the same + assert result1 == result2 == 15 + + +def test_cache_with_args_and_excluded_args(cache): + @cache.checkpoint(exclude_args=["excluded_arg"]) + def test_function(*args, **kwargs): + return sum(args) + sum(kwargs.values()) + + # Call the function + result = test_function(1, 2, excluded_arg=3, included_arg=4) + assert result == 10 # 1 + 2 + 3 + 4 + + # Verify that 'excluded_arg' does not affect the cache key + cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"] + assert len(cache_files) == 1 + + # Call again with a different 'excluded_arg' + result_cached = test_function(1, 2, excluded_arg=5, included_arg=4) + assert result_cached == 10 # Cached result from the first call + + # Demonstrate that excluding arguments affecting the output can lead to incorrect caching + assert result_cached != 12 # The result is not updated due to caching + + +def test_cache_with_default_arguments(cache): + @cache.checkpoint(exclude_args=["excluded_arg"]) + def test_function(x, excluded_arg="default"): + return f"{x}_{excluded_arg}" + + # Call the function without specifying 'excluded_arg' + result1 = test_function(5) + assert result1 == "5_default" + + # Call the function with a different 'excluded_arg' + result2 = test_function(5, excluded_arg="changed") + assert result2 == "5_default" # Cached result from the first call + + # Demonstrate that excluding arguments affecting the output can lead to incorrect caching + assert result2 != "5_changed" # The result is not updated due to caching + + +def test_cache_with_unpickleable_return_and_excluded_args(cache): + @cache.checkpoint(exclude_args=["x"]) + def test_function(x): + return threading.Lock() # Unpickleable return value + + # Even though 'x' is excluded, the return value is unpickleable + with pytest.raises(Exception): + test_function(5) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index aee0931..da7912f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,23 +1,13 @@ -import os -import shutil -import pytest -from pickled_pipeline import Cache - -TEST_CACHE_DIR = "test_pipeline_cache" - +""" +Tests that focus on how the caching mechanism integrates with a pipeline of functions, +ensuring that the cache works correctly in a real-world scenario involving multiple steps. +""" -@pytest.fixture(scope="function") -def cache(): - # Set up: Create a Cache instance with a test cache directory - cache = Cache(cache_dir=TEST_CACHE_DIR) - yield cache - # Tear down: Remove the test cache directory after each test - if os.path.exists(TEST_CACHE_DIR): - shutil.rmtree(TEST_CACHE_DIR) +import os def test_pipeline(cache): - # Re-define the pipeline functions using the test cache + # Define the pipeline functions using the test cache @cache.checkpoint() def step1_user_input(user_text): return user_text @@ -63,7 +53,7 @@ def run_pipeline(user_text): assert summary == expected_summary # Verify that cache files were created (excluding manifest) - cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"] + cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"] assert len(cache_files) == 5 # Truncate the cache from step3 onwards @@ -71,7 +61,7 @@ def run_pipeline(user_text): # Ensure that only two cache files remain (excluding manifest) cache_files_after_truncate = [ - f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json" + f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json" ] assert len(cache_files_after_truncate) == 2 @@ -81,7 +71,7 @@ def run_pipeline(user_text): # Verify that all cache files are recreated (excluding manifest) cache_files_final = [ - f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json" + f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json" ] assert len(cache_files_final) == 5 @@ -126,7 +116,7 @@ def run_pipeline(user_text): # Verify that cache files were created (excluding manifest) cache_files_after_first_run = [ - f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json" + f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json" ] num_cache_files_first_run = len(cache_files_after_first_run) assert num_cache_files_first_run == 5 @@ -137,7 +127,7 @@ def run_pipeline(user_text): # Verify that new cache files were created for the new input (excluding manifest) cache_files_after_second_run = [ - f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json" + f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json" ] num_cache_files_second_run = len(cache_files_after_second_run) assert num_cache_files_second_run == 10 # Should have 5 new cache files diff --git a/tests/test_truncation.py b/tests/test_truncation.py new file mode 100644 index 0000000..4f16ce0 --- /dev/null +++ b/tests/test_truncation.py @@ -0,0 +1,65 @@ +""" +Tests for the cache truncation functionality of the Cache class. +Ensures cached results can be invalidated from a specified checkpoint +and verifies that the cache can be correctly rebuilt afterward. +""" + +import os + + +def test_truncate_cache(cache): + # Define functions with arbitrary names + @cache.checkpoint() + def examine_input(): + return "input" + + @cache.checkpoint() + def open_document(): + return "document" + + @cache.checkpoint() + def process_details(): + return "details" + + @cache.checkpoint() + def analyze_result(): + return "result" + + # Run the pipeline + _ = examine_input() + _ = open_document() + _ = process_details() + _ = analyze_result() + + # Check that manifest has the correct order + expected_order = [ + "examine_input", + "open_document", + "process_details", + "analyze_result", + ] + assert cache.checkpoint_order == expected_order + + # Truncate from "open_document" + cache.truncate_cache("open_document") + + # Verify that cache files for "open_document" and subsequent checkpoints are deleted + remaining_checkpoints = cache.list_checkpoints() + assert remaining_checkpoints == ["examine_input"] + + # Verify that cache files are as expected (excluding the manifest) + cache_dir = cache.cache_dir # Access the cache directory from the cache instance + cache_files = [f for f in os.listdir(cache_dir) if f != "cache_manifest.json"] + + # There should be cache files only for 'examine_input' + assert len(cache_files) == 1 + assert cache_files[0].startswith("examine_input__") + + # Re-run the truncated steps + _ = open_document() + _ = process_details() + _ = analyze_result() + + # Verify that the cache is rebuilt + remaining_checkpoints = cache.list_checkpoints() + assert remaining_checkpoints == expected_order