Skip to content

Commit

Permalink
Added new argument exclusion feature
Browse files Browse the repository at this point in the history
  • Loading branch information
btfranklin committed Nov 11, 2024
1 parent 1b78723 commit c2683c3
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 12 deletions.
72 changes: 63 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ A Python package for caching repeat runs of pipelines that have expensive operations.
- **Checkpointing**: Assign checkpoints to pipeline steps to manage caching and recomputation.
- **Cache Truncation**: Remove cached results from a specific checkpoint onwards to recompute parts of the pipeline.
- **Input Sensitivity**: Cache keys are sensitive to function arguments, ensuring that different inputs result in different cache entries.
- **Argument Exclusion**: Exclude specific arguments from the cache key to handle unpickleable objects or sensitive data.
- **Easy Integration**: Minimal changes to your existing codebase are needed to integrate caching.

## Installation
Expand Down Expand Up @@ -68,7 +69,18 @@ def my_function(...):
pass
```

This flexibility allows you to simplify your code and reduce redundancy when the function name suffices as a unique identifier.
### Excluding Arguments from the Cache Key

If your function accepts arguments that are unpickleable or contain sensitive information (like database connections or API clients), you can exclude them from the cache key using the `exclude_args` parameter:

```python
@cache.checkpoint(exclude_args=['unpickleable_arg'])
def my_function(unpickleable_arg, other_arg):
    """Placeholder example: only `other_arg` contributes to the cache key."""
```

- **`exclude_args`**: A list of argument names (as strings) to exclude from the cache key. This is useful when certain arguments cannot be pickled or should not influence caching.

### Building a Pipeline

Expand All @@ -87,28 +99,42 @@ def run_pipeline(user_text):
### Example Functions

```python
@cache.checkpoint(name="step2_enhance_text")
@cache.checkpoint()
def step2_enhance_text(text):
    """Uppercase *text*, standing in for an expensive transformation."""
    return text.upper()

@cache.checkpoint(name="step3_produce_document")
@cache.checkpoint()
def step3_produce_document(enhanced_text):
    """Wrap *enhanced_text* in a document string."""
    return f"Document based on: {enhanced_text}"

@cache.checkpoint(name="step4_generate_additional_documents")
@cache.checkpoint()
def step4_generate_additional_documents(document):
    """Derive three versioned variants of *document*."""
    versions = []
    for i in range(3):
        versions.append(f"{document} - Version {i}")
    return versions

@cache.checkpoint(name="step5_summarize_documents")
@cache.checkpoint()
def step5_summarize_documents(documents):
    """Join *documents* into a single summary line."""
    joined = ", ".join(documents)
    return "Summary of documents: " + joined
```

### Handling Unpickleable Objects

For functions that require unpickleable objects, such as API clients or database connections, you can exclude these from the cache key:

```python
@cache.checkpoint(exclude_args=['llm_client'])
def enhance_domain(llm_client, domain):
    """Process *domain* with *llm_client*.

    The client is excluded from the cache key, so caching depends only on
    *domain*.
    """
    return llm_client.process(domain)
```

By excluding `llm_client` from the cache key, you prevent serialization errors and ensure that caching is based only on the relevant arguments.

### Running the Pipeline

```python
Expand Down Expand Up @@ -137,22 +163,37 @@ from pickled_pipeline import Cache

cache = Cache(cache_dir="my_cache_directory")

@cache.checkpoint(name="step1_user_input")
@cache.checkpoint()
def step1_user_input(user_text):
    """Identity step: feed the raw user text into the pipeline."""
    return user_text

@cache.checkpoint(name="step2_enhance_text")
@cache.checkpoint()
def step2_enhance_text(text):
    """Simulated expensive step: return the uppercase form of *text*."""
    result = text.upper()
    return result

# ... (other steps)
@cache.checkpoint()
def step3_produce_document(enhanced_text):
    """Build the document string from *enhanced_text*."""
    return f"Document based on: {enhanced_text}"

@cache.checkpoint()
def step4_generate_additional_documents(document):
    """Produce versions 0 through 2 of *document*."""
    return [f"{document} - Version {i}" for i in range(3)]

@cache.checkpoint()
def step5_summarize_documents(documents):
    """Collapse *documents* into one comma-separated summary string."""
    return "Summary of documents: " + ", ".join(documents)

def run_pipeline(user_text):
    """Thread *user_text* through every checkpointed pipeline step and return the final summary."""
    stage = step1_user_input(user_text)
    stage = step2_enhance_text(stage)
    stage = step3_produce_document(stage)
    stage = step4_generate_additional_documents(stage)
    return step5_summarize_documents(stage)

if __name__ == "__main__":
Expand All @@ -173,6 +214,19 @@ summary1 = run_pipeline("First input from user.")
summary2 = run_pipeline("Second input from user.")
```

### Using Exclude Args in Practice

Suppose you have a function that interacts with an API client:

```python
@cache.checkpoint(exclude_args=['api_client'])
def fetch_data(api_client, endpoint):
    """GET *endpoint* via *api_client* and return the decoded JSON body.

    The client object is excluded from the cache key; caching is keyed on
    *endpoint* alone.
    """
    return api_client.get(endpoint).json()
```

By excluding `api_client` from the cache key, you avoid serialization issues with the client object and ensure that caching is based on the `endpoint` parameter.

## Contributing

Contributions are welcome! Please open an issue or submit a pull request.
Expand Down
20 changes: 17 additions & 3 deletions src/pickled_pipeline/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,30 @@ def __init__(self, cache_dir="pipeline_cache"):
else:
self.checkpoint_order = []

def checkpoint(self, name=None):
def checkpoint(self, name=None, exclude_args=None):
if exclude_args is None:
exclude_args = []

def decorator(func):
checkpoint_name = name or func.__name__

@wraps(func)
def wrapper(*args, **kwargs):
# Create a unique key based on the checkpoint name and function arguments
key_input = (checkpoint_name, args, kwargs)
# Map arguments to their names
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
args_dict = dict(zip(arg_names, args))
args_dict.update(kwargs)

# Remove excluded arguments
for arg in exclude_args:
args_dict.pop(arg, None)

# Create a unique key based on the checkpoint name and filtered arguments
key_input = (checkpoint_name, args_dict)
key_hash = hashlib.md5(pickle.dumps(key_input)).hexdigest()
cache_filename = f"{checkpoint_name}__{key_hash}.pkl"
cache_path = os.path.join(self.cache_dir, cache_filename)

if os.path.exists(cache_path):
with open(cache_path, "rb") as f:
result = pickle.load(f)
Expand All @@ -37,6 +50,7 @@ def wrapper(*args, **kwargs):
with open(cache_path, "wb") as f:
pickle.dump(result, f)
print(f"[{checkpoint_name}] Computed result and saved to cache.")

# Record the checkpoint name if not already recorded
if checkpoint_name not in self.checkpoint_order:
self.checkpoint_order.append(checkpoint_name)
Expand Down

0 comments on commit c2683c3

Please sign in to comment.