Skip to content

Commit

Permalink
Updated unit tests to demonstrate argument exclusion feature
Browse files Browse the repository at this point in the history
  • Loading branch information
btfranklin committed Nov 12, 2024
1 parent c2683c3 commit 54230c5
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 103 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def my_function(unpickleable_arg, other_arg):

- **`exclude_args`**: A list of argument names (as strings) to exclude from the cache key. This is useful when certain arguments cannot be pickled or should not influence caching.

**Warning**: Excluding arguments that affect the function's output can lead to incorrect caching behavior. The cache will return the result based on the included arguments, ignoring changes in the excluded arguments. Only exclude arguments that do not influence the function's output, such as unpickleable objects or instances that do not affect computation.

### Building a Pipeline

Here's an example of how to build a pipeline using cached functions:
Expand Down
2 changes: 1 addition & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ excludes = ["tests/**"]

[tool.pdm.dev-dependencies]
dev = [
"pytest>=8.3.1",
"pytest>=8.3.3",
"flake8>=7.1.0",
]

Expand Down
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
import shutil
import pytest
from pickled_pipeline import Cache

TEST_CACHE_DIR = "test_pipeline_cache"


@pytest.fixture(scope="function")
def cache():
# Set up: Create a Cache instance with a test cache directory
cache = Cache(cache_dir=TEST_CACHE_DIR)
yield cache
# Tear down: Remove the test cache directory after each test
if os.path.exists(TEST_CACHE_DIR):
shutil.rmtree(TEST_CACHE_DIR)
123 changes: 43 additions & 80 deletions tests/test_cache.py → tests/test_cache_basic.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
"""
Tests for the basic functionality of the pickled_pipeline.Cache class.
These tests ensure that caching, cache retrieval, and handling of different arguments work as expected.
"""

import os
import shutil
import pytest
from pickled_pipeline import Cache

# Define a temporary directory for caching during tests
TEST_CACHE_DIR = "test_pipeline_cache"


@pytest.fixture(scope="function")
def cache():
# Set up: Create a Cache instance with a test cache directory
cache = Cache(cache_dir=TEST_CACHE_DIR)
yield cache
# Tear down: Remove the test cache directory after each test
if os.path.exists(TEST_CACHE_DIR):
shutil.rmtree(TEST_CACHE_DIR)


def test_cache_checkpoint(cache):
Expand All @@ -26,16 +16,20 @@ def test_function(x):
# Call the function for the first time
result1 = test_function(3)
assert result1 == 9

# Check that the cache file was created (excluding the manifest)
cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"]
cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"]
assert len(cache_files) == 1

# Call the function again with the same argument
result2 = test_function(3)
assert result2 == 9

# Ensure the cache file count hasn't increased
cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"]
assert len(cache_files) == 1
cache_files_after = [
f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"
]
assert len(cache_files_after) == 1


def test_custom_checkpoint_name(cache):
Expand All @@ -51,7 +45,7 @@ def test_function(x):
assert cache.checkpoint_order == ["custom_checkpoint_name"]

# Check that the cache file was created with the custom name
cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"]
cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"]
assert len(cache_files) == 1
assert cache_files[0].startswith("custom_checkpoint_name__")

Expand All @@ -61,7 +55,7 @@ def test_function(x):

# Ensure no new cache files were created
cache_files_after = [
f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"
f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"
]
assert len(cache_files_after) == 1

Expand All @@ -71,7 +65,7 @@ def test_function(x):

# Verify that a new cache file was created for the new input
cache_files_final = [
f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"
f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"
]
assert len(cache_files_final) == 2

Expand All @@ -91,66 +85,10 @@ def test_function(x):
assert result3 == 2

# Check that two cache files were created (excluding the manifest)
cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"]
cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"]
assert len(cache_files) == 2


def test_truncate_cache(cache):
# Define functions with arbitrary names
@cache.checkpoint()
def examine_input():
return "input"

@cache.checkpoint()
def open_document():
return "document"

@cache.checkpoint()
def process_details():
return "details"

@cache.checkpoint()
def analyze_result():
return "result"

# Run the pipeline
_ = examine_input()
_ = open_document()
_ = process_details()
_ = analyze_result()

# Check that manifest has the correct order
expected_order = [
"examine_input",
"open_document",
"process_details",
"analyze_result",
]
assert cache.checkpoint_order == expected_order

# Truncate from "open_document"
cache.truncate_cache("open_document")

# Verify that cache files for "open_document" and subsequent checkpoints are deleted
remaining_checkpoints = cache.list_checkpoints()
assert remaining_checkpoints == ["examine_input"]

# Verify that cache files are as expected (excluding the manifest)
cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"]
# There should be cache files only for 'examine_input'
assert len(cache_files) == 1
assert cache_files[0].startswith("examine_input__")

# Re-run the truncated steps
_ = open_document()
_ = process_details()
_ = analyze_result()

# Verify that the cache is rebuilt
remaining_checkpoints = cache.list_checkpoints()
assert remaining_checkpoints == expected_order


def test_cache_with_complex_arguments(cache):
@cache.checkpoint()
def complex_function(a, b):
Expand All @@ -176,6 +114,31 @@ def non_pickleable_function():
non_pickleable_function()


def test_cache_with_non_serializable_included_arg(cache):
import threading

unpickleable_arg = threading.Lock()

@cache.checkpoint()
def test_function(unpickleable_arg):
return "result"

# This should raise an exception because the argument cannot be pickled
with pytest.raises(Exception):
test_function(unpickleable_arg)


def test_cache_error_with_unpickleable_return(cache):
@cache.checkpoint()
def test_function():
import threading

return threading.Lock() # Unpickleable return value

with pytest.raises(Exception):
test_function()


def test_clear_cache(cache):
@cache.checkpoint()
def step1():
Expand All @@ -190,15 +153,15 @@ def step2():
_ = step2()

# Ensure cache files are created (excluding manifest)
cache_files = [f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"]
cache_files = [f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"]
assert len(cache_files) == 2

# Clear the cache
cache.clear_cache()

# Verify that cache files are deleted (excluding manifest)
cache_files_after_clear = [
f for f in os.listdir(TEST_CACHE_DIR) if f != "cache_manifest.json"
f for f in os.listdir(cache.cache_dir) if f != "cache_manifest.json"
]
assert len(cache_files_after_clear) == 0

Expand Down
Loading

0 comments on commit 54230c5

Please sign in to comment.