Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2+ dimensional numpy array serde with msgpack #551

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion nos/server/http/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
from pathlib import Path
from typing import Any, Dict

import msgpack
import msgpack_numpy as m
import numpy as np
from fastapi import UploadFile
from fastapi.responses import FileResponse
from PIL import Image


m.patch()


def encode_item(v: Any) -> Any:
"""Encode an item to a JSON-serializable object."""
if isinstance(v, dict):
Expand All @@ -22,7 +27,8 @@ def encode_item(v: Any) -> Any:
if v.ndim <= 2:
return v.tolist()
else:
raise ValueError(f"Unsupported ndarray dimension: {v.ndim}")
arr_b64 = base64.b64encode(msgpack.packb(v)).decode()
return f"data:application/numpy;base64,{arr_b64}"
elif isinstance(v, Path):
return FileResponse(v)
else:
Expand All @@ -37,6 +43,9 @@ def decode_item(v: Any) -> Any:
return [decode_item(x) for x in v]
elif isinstance(v, str) and v.startswith("data:image/"):
return base64_str_to_image(v)
elif isinstance(v, str) and v.startswith("data:application/numpy;base64,"):
arr_b64 = v[len("data:application/numpy;base64,") :]
return msgpack.unpackb(base64.b64decode(arr_b64), raw=False)
else:
return v

Expand Down
2 changes: 1 addition & 1 deletion nos/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.3.1"
2 changes: 2 additions & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ gitpython
grpcio-tools<=1.49.1
humanize
loguru>=0.7.0
msgpack
msgpack-numpy
opencv-python-headless>=4.6.0.66
pandas
Pillow
Expand Down
41 changes: 41 additions & 0 deletions tests/server/test_server_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import logging
from pathlib import Path

import numpy as np
import psutil
import pytest
from fastapi.responses import FileResponse
from PIL import Image

import nos
from nos.common.system import has_docker
from nos.server import InferenceServiceRuntime
from nos.server.http._utils import decode_item, encode_item
from nos.test.utils import AVAILABLE_RUNTIMES


Expand All @@ -14,6 +19,42 @@
pytestmark = pytest.mark.skipif(not has_docker() or NUM_CPUS < 4, reason="docker is not installed")


def test_encode_decode_item():
# Test encoding a dictionary
input_dict = {"key1": "value1", "key2": "value2"}
expected_dict = {"key1": "value1", "key2": "value2"}
assert encode_item(input_dict) == expected_dict

# Test encoding a list
input_list = [1, 2, 3]
expected_list = [1, 2, 3]
assert encode_item(input_list) == expected_list

# Test encoding a tuple
input_tuple = (4, 5, 6)
expected_tuple = [4, 5, 6]
assert encode_item(input_tuple) == expected_tuple

# Test encoding an Image object
input_image = Image.new("RGB", (100, 100))
assert encode_item(input_image).startswith("data:image/")
assert (decode_item(encode_item(input_image)) == input_image).all()

# Test encoding an ndarray object
input_ndarray = np.array([1, 2, 3])
expected_ndarray = [1, 2, 3]
assert encode_item(input_ndarray) == expected_ndarray

# Test encoding a 3D ndarray object
input_ndarray = np.random.rand(3, 3, 3)
assert encode_item(input_ndarray).startswith("data:application/numpy;base64,")
assert (decode_item(encode_item(input_ndarray)) == input_ndarray).all()

# Test encoding a Path object
input_path = Path("/path/to/file.txt")
assert isinstance(encode_item(input_path), FileResponse)


@pytest.mark.parametrize("runtime", AVAILABLE_RUNTIMES)
def test_nos_init(runtime):
"""Test the NOS server daemon initialization.
Expand Down
Loading