Skip to content

Commit

Permalink
feat: use importlib when deserializing callables
Browse files Browse the repository at this point in the history
  • Loading branch information
LastRemote committed Jan 3, 2025
1 parent 7b4d9ba commit c339c12
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
9 changes: 5 additions & 4 deletions haystack/utils/callable_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
# SPDX-License-Identifier: Apache-2.0

import inspect
import sys
from typing import Callable, Optional

from haystack import DeserializationError
from haystack.utils.type_serialization import thread_safe_import


def serialize_callable(callable_handle: Callable) -> str:
Expand Down Expand Up @@ -37,9 +37,10 @@ def deserialize_callable(callable_handle: str) -> Optional[Callable]:
parts = callable_handle.split(".")
module_name = ".".join(parts[:-1])
function_name = parts[-1]
module = sys.modules.get(module_name, None)
if not module:
raise DeserializationError(f"Could not locate the module of the callable: {module_name}")
try:
module = thread_safe_import(module_name)
except Exception as e:
raise DeserializationError(f"Could not locate the module of the callable: {module_name}") from e
deserialized_callable = getattr(module, function_name, None)
if not deserialized_callable:
raise DeserializationError(f"Could not locate the callable: {function_name}")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
enhancements:
- |
Improved deserialization of callables by using `importlib` instead of `sys.modules`.
This change allows importing local functions and classes that are not in `sys.modules`
when deserializing callables.
9 changes: 9 additions & 0 deletions test/utils/test_callable_serialization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import requests

from haystack import DeserializationError
from haystack.components.generators.utils import print_streaming_chunk
from haystack.utils import serialize_callable, deserialize_callable

Expand Down Expand Up @@ -36,3 +38,10 @@ def test_callable_deserialization_non_local():
result = serialize_callable(requests.api.get)
fn = deserialize_callable(result)
assert fn is requests.api.get


def test_callable_deserialization_error():
with pytest.raises(DeserializationError):
deserialize_callable("this.is.not.a.valid.module")
with pytest.raises(DeserializationError):
deserialize_callable("sys.foobar")

0 comments on commit c339c12

Please sign in to comment.