diff --git a/README.md b/README.md index fbbdde8..887beb6 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ A Python package for caching repeat runs of pipelines that have expensive operat - **Checkpointing**: Assign checkpoints to pipeline steps to manage caching and recomputation. - **Cache Truncation**: Remove cached results from a specific checkpoint onwards to recompute parts of the pipeline. - **Input Sensitivity**: Cache keys are sensitive to function arguments, ensuring that different inputs result in different cache entries. +- **Argument Exclusion**: Exclude specific arguments from the cache key to handle unpickleable objects or sensitive data. - **Easy Integration**: Minimal changes to your existing codebase are needed to integrate caching. ## Installation @@ -68,7 +69,18 @@ def my_function(...): pass ``` -This flexibility allows you to simplify your code and reduce redundancy when the function name suffices as a unique identifier. +### Excluding Arguments from the Cache Key + +If your function accepts arguments that are unpickleable or contain sensitive information (like database connections or API clients), you can exclude them from the cache key using the `exclude_args` parameter: + +```python +@cache.checkpoint(exclude_args=['unpickleable_arg']) +def my_function(unpickleable_arg, other_arg): + # Your code here + pass +``` + +- **`exclude_args`**: A list of argument names (as strings) to exclude from the cache key. This is useful when certain arguments cannot be pickled or should not influence caching. Calls that differ only in excluded arguments share the same cache entry.
### Building a Pipeline @@ -87,28 +99,42 @@ def run_pipeline(user_text): ### Example Functions ```python -@cache.checkpoint(name="step2_enhance_text") +@cache.checkpoint() def step2_enhance_text(text): # Simulate an expensive operation enhanced_text = text.upper() return enhanced_text -@cache.checkpoint(name="step3_produce_document") +@cache.checkpoint() def step3_produce_document(enhanced_text): document = f"Document based on: {enhanced_text}" return document -@cache.checkpoint(name="step4_generate_additional_documents") +@cache.checkpoint() def step4_generate_additional_documents(document): documents = [f"{document} - Version {i}" for i in range(3)] return documents -@cache.checkpoint(name="step5_summarize_documents") +@cache.checkpoint() def step5_summarize_documents(documents): summary = "Summary of documents: " + ", ".join(documents) return summary ``` +### Handling Unpickleable Objects + +For functions that require unpickleable objects, such as API clients or database connections, you can exclude these from the cache key: + +```python +@cache.checkpoint(exclude_args=['llm_client']) +def enhance_domain(llm_client, domain): + # Use llm_client to perform operations + result = llm_client.process(domain) + return result +``` + +By excluding `llm_client` from the cache key, you prevent serialization errors and ensure that caching is based only on the relevant arguments. + ### Running the Pipeline ```python @@ -137,22 +163,37 @@ from pickled_pipeline import Cache cache = Cache(cache_dir="my_cache_directory") -@cache.checkpoint(name="step1_user_input") +@cache.checkpoint() def step1_user_input(user_text): return user_text -@cache.checkpoint(name="step2_enhance_text") +@cache.checkpoint() def step2_enhance_text(text): # Simulate an expensive operation enhanced_text = text.upper() return enhanced_text -# ... 
(other steps) +@cache.checkpoint() +def step3_produce_document(enhanced_text): + document = f"Document based on: {enhanced_text}" + return document + +@cache.checkpoint() +def step4_generate_additional_documents(document): + documents = [f"{document} - Version {i}" for i in range(3)] + return documents + +@cache.checkpoint() +def step5_summarize_documents(documents): + summary = "Summary of documents: " + ", ".join(documents) + return summary def run_pipeline(user_text): text = step1_user_input(user_text) enhanced_text = step2_enhance_text(text) - # ... (other steps) + document = step3_produce_document(enhanced_text) + documents = step4_generate_additional_documents(document) + summary = step5_summarize_documents(documents) return summary if __name__ == "__main__": @@ -173,6 +214,19 @@ summary1 = run_pipeline("First input from user.") summary2 = run_pipeline("Second input from user.") ``` +### Using Exclude Args in Practice + +Suppose you have a function that interacts with an API client: + +```python +@cache.checkpoint(exclude_args=['api_client']) +def fetch_data(api_client, endpoint): + response = api_client.get(endpoint) + return response.json() +``` + +By excluding `api_client` from the cache key, you avoid serialization issues with the client object and ensure that caching is based on the `endpoint` parameter. + ## Contributing Contributions are welcome! Please open an issue or submit a pull request. 
diff --git a/src/pickled_pipeline/cache.py b/src/pickled_pipeline/cache.py index 5b74509..6ca0af9 100644 --- a/src/pickled_pipeline/cache.py +++ b/src/pickled_pipeline/cache.py @@ -17,17 +17,30 @@ def __init__(self, cache_dir="pipeline_cache"): else: self.checkpoint_order = [] - def checkpoint(self, name=None): + def checkpoint(self, name=None, exclude_args=None): + if exclude_args is None: + exclude_args = [] + def decorator(func): checkpoint_name = name or func.__name__ @wraps(func) def wrapper(*args, **kwargs): - # Create a unique key based on the checkpoint name and function arguments - key_input = (checkpoint_name, args, kwargs) + # Map arguments to their names + arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + args_dict = dict(zip(arg_names, args)) + args_dict.update(kwargs) + + # Remove excluded arguments + for arg in exclude_args: + args_dict.pop(arg, None) + + # Create a unique key based on the checkpoint name and filtered arguments + key_input = (checkpoint_name, args_dict) key_hash = hashlib.md5(pickle.dumps(key_input)).hexdigest() cache_filename = f"{checkpoint_name}__{key_hash}.pkl" cache_path = os.path.join(self.cache_dir, cache_filename) + if os.path.exists(cache_path): with open(cache_path, "rb") as f: result = pickle.load(f) @@ -37,6 +50,7 @@ def wrapper(*args, **kwargs): with open(cache_path, "wb") as f: pickle.dump(result, f) print(f"[{checkpoint_name}] Computed result and saved to cache.") + # Record the checkpoint name if not already recorded if checkpoint_name not in self.checkpoint_order: self.checkpoint_order.append(checkpoint_name)