Skip to content

Commit

Permalink
2+ dimensional numpy array serde with msgpack
Browse files Browse the repository at this point in the history
  • Loading branch information
spillai committed May 16, 2024
1 parent 20cad80 commit fbaacbd
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 2 deletions.
11 changes: 10 additions & 1 deletion nos/server/http/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
from pathlib import Path
from typing import Any, Dict

import msgpack
import msgpack_numpy as m
import numpy as np
from fastapi import UploadFile
from fastapi.responses import FileResponse
from PIL import Image


m.patch()


def encode_item(v: Any) -> Any:
"""Encode an item to a JSON-serializable object."""
if isinstance(v, dict):
Expand All @@ -22,7 +27,8 @@ def encode_item(v: Any) -> Any:
if v.ndim <= 2:
return v.tolist()
else:
raise ValueError(f"Unsupported ndarray dimension: {v.ndim}")
arr_b64 = base64.b64encode(msgpack.packb(v)).decode()
return f"data:application/numpy;base64,{arr_b64}"
elif isinstance(v, Path):
return FileResponse(v)
else:
Expand All @@ -37,6 +43,9 @@ def decode_item(v: Any) -> Any:
return [decode_item(x) for x in v]
elif isinstance(v, str) and v.startswith("data:image/"):
return base64_str_to_image(v)
elif isinstance(v, str) and v.startswith("data:application/numpy;base64,"):
arr_b64 = v[len("data:application/numpy;base64,") :]
return msgpack.unpackb(base64.b64decode(arr_b64), raw=False)
else:
return v

Expand Down
2 changes: 1 addition & 1 deletion nos/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.3.1"
2 changes: 2 additions & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ gitpython
grpcio-tools<=1.49.1
humanize
loguru>=0.7.0
msgpack
msgpack-numpy
opencv-python-headless>=4.6.0.66
pandas
Pillow
Expand Down
41 changes: 41 additions & 0 deletions tests/server/test_server_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import logging
from pathlib import Path

import numpy as np
import psutil
import pytest
from fastapi.responses import FileResponse
from PIL import Image

import nos
from nos.common.system import has_docker
from nos.server import InferenceServiceRuntime
from nos.server.http._utils import decode_item, encode_item
from nos.test.utils import AVAILABLE_RUNTIMES


Expand All @@ -14,6 +19,42 @@
pytestmark = pytest.mark.skipif(not has_docker() or NUM_CPUS < 4, reason="docker is not installed")


def test_encode_decode_item():
# Test encoding a dictionary
input_dict = {"key1": "value1", "key2": "value2"}
expected_dict = {"key1": "value1", "key2": "value2"}
assert encode_item(input_dict) == expected_dict

# Test encoding a list
input_list = [1, 2, 3]
expected_list = [1, 2, 3]
assert encode_item(input_list) == expected_list

# Test encoding a tuple
input_tuple = (4, 5, 6)
expected_tuple = [4, 5, 6]
assert encode_item(input_tuple) == expected_tuple

# Test encoding an Image object
input_image = Image.new("RGB", (100, 100))
assert encode_item(input_image).startswith("data:image/")
assert (decode_item(encode_item(input_image)) == input_image).all()

# Test encoding an ndarray object
input_ndarray = np.array([1, 2, 3])
expected_ndarray = [1, 2, 3]
assert encode_item(input_ndarray) == expected_ndarray

# Test encoding a 3D ndarray object
input_ndarray = np.random.rand(3, 3, 3)
assert encode_item(input_ndarray).startswith("data:application/numpy;base64,")
assert (decode_item(encode_item(input_ndarray)) == input_ndarray).all()

# Test encoding a Path object
input_path = Path("/path/to/file.txt")
assert isinstance(encode_item(input_path), FileResponse)


@pytest.mark.parametrize("runtime", AVAILABLE_RUNTIMES)
def test_nos_init(runtime):
"""Test the NOS server daemon initialization.
Expand Down

0 comments on commit fbaacbd

Please sign in to comment.