diff --git a/nos/server/http/_utils.py b/nos/server/http/_utils.py index d88226e8..3e95973d 100644 --- a/nos/server/http/_utils.py +++ b/nos/server/http/_utils.py @@ -4,12 +4,17 @@ from pathlib import Path from typing import Any, Dict +import msgpack +import msgpack_numpy as m import numpy as np from fastapi import UploadFile from fastapi.responses import FileResponse from PIL import Image +m.patch() + + def encode_item(v: Any) -> Any: """Encode an item to a JSON-serializable object.""" if isinstance(v, dict): @@ -22,7 +27,8 @@ def encode_item(v: Any) -> Any: if v.ndim <= 2: return v.tolist() else: - raise ValueError(f"Unsupported ndarray dimension: {v.ndim}") + arr_b64 = base64.b64encode(msgpack.packb(v)).decode() + return f"data:application/numpy;base64,{arr_b64}" elif isinstance(v, Path): return FileResponse(v) else: @@ -37,6 +43,9 @@ def decode_item(v: Any) -> Any: return [decode_item(x) for x in v] elif isinstance(v, str) and v.startswith("data:image/"): return base64_str_to_image(v) + elif isinstance(v, str) and v.startswith("data:application/numpy;base64,"): + arr_b64 = v[len("data:application/numpy;base64,") :] + return msgpack.unpackb(base64.b64decode(arr_b64), raw=False) else: return v diff --git a/nos/version.py b/nos/version.py index 493f7415..260c070a 100644 --- a/nos/version.py +++ b/nos/version.py @@ -1 +1 @@ -__version__ = "0.3.0" +__version__ = "0.3.1" diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 7036cf30..6567dbbe 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -6,6 +6,8 @@ gitpython grpcio-tools<=1.49.1 humanize loguru>=0.7.0 +msgpack +msgpack-numpy opencv-python-headless>=4.6.0.66 pandas Pillow diff --git a/tests/server/test_server_utils.py b/tests/server/test_server_utils.py index ec18bab1..4b4d9938 100644 --- a/tests/server/test_server_utils.py +++ b/tests/server/test_server_utils.py @@ -1,11 +1,16 @@ import logging +from pathlib import Path +import numpy as np import psutil import pytest +from fastapi.responses import FileResponse +from PIL import Image import nos from nos.common.system import has_docker from nos.server import InferenceServiceRuntime +from nos.server.http._utils import decode_item, encode_item from nos.test.utils import AVAILABLE_RUNTIMES @@ -14,6 +19,42 @@ pytestmark = pytest.mark.skipif(not has_docker() or NUM_CPUS < 4, reason="docker is not installed") +def test_encode_decode_item(): + # Test encoding a dictionary + input_dict = {"key1": "value1", "key2": "value2"} + expected_dict = {"key1": "value1", "key2": "value2"} + assert encode_item(input_dict) == expected_dict + + # Test encoding a list + input_list = [1, 2, 3] + expected_list = [1, 2, 3] + assert encode_item(input_list) == expected_list + + # Test encoding a tuple + input_tuple = (4, 5, 6) + expected_tuple = [4, 5, 6] + assert encode_item(input_tuple) == expected_tuple + + # Test encoding an Image object + input_image = Image.new("RGB", (100, 100)) + assert encode_item(input_image).startswith("data:image/") + assert (decode_item(encode_item(input_image)) == input_image).all() + + # Test encoding an ndarray object + input_ndarray = np.array([1, 2, 3]) + expected_ndarray = [1, 2, 3] + assert encode_item(input_ndarray) == expected_ndarray + + # Test encoding a 3D ndarray object + input_ndarray = np.random.rand(3, 3, 3) + assert encode_item(input_ndarray).startswith("data:application/numpy;base64,") + assert (decode_item(encode_item(input_ndarray)) == input_ndarray).all() + + # Test encoding a Path object + input_path = Path("/path/to/file.txt") + assert isinstance(encode_item(input_path), FileResponse) + + @pytest.mark.parametrize("runtime", AVAILABLE_RUNTIMES) def test_nos_init(runtime): """Test the NOS server daemon initialization.