From cc4fcb8b149e24338059f88fe539aba8f6a883f1 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 6 Nov 2023 17:35:02 +0800 Subject: [PATCH] added python sdk --- Taskfile.yaml | 3 ++ app/backend/segmentation.go | 7 ++-- app/backend/track.go | 4 +- cmd/sam/action/hit.go | 2 +- module/sam/main.go | 2 +- module/sam/server/decoder.go | 2 +- module/sam/server/embed.go | 2 +- module/sam/server/introspect.go | 2 +- module/sam/server/new.go | 2 +- module/track/main.go | 2 +- module/track/server/new.go | 2 +- module/track/server/stream.go | 2 +- module/track/server/track.go | 4 +- proto/buf.gen.service.yaml | 18 +++++++-- python/.gitignore | 5 +++ python/README.md | 7 ++++ python/develop.md | 8 ++++ python/nutsh/__init__.py | 4 ++ python/nutsh/lib/__init__.py | 0 python/nutsh/lib/image.py | 66 +++++++++++++++++++++++++++++++++ python/nutsh/lib/logging.py | 10 +++++ python/nutsh/track.py | 37 ++++++++++++++++++ python/nutsh/track_grpc.py | 56 ++++++++++++++++++++++++++++ python/requirements.dev.txt | 4 ++ python/setup.py | 21 +++++++++++ task/python.yaml | 23 ++++++++++++ 26 files changed, 276 insertions(+), 19 deletions(-) create mode 100644 python/.gitignore create mode 100644 python/README.md create mode 100644 python/develop.md create mode 100644 python/nutsh/__init__.py create mode 100644 python/nutsh/lib/__init__.py create mode 100644 python/nutsh/lib/image.py create mode 100644 python/nutsh/lib/logging.py create mode 100644 python/nutsh/track.py create mode 100644 python/nutsh/track_grpc.py create mode 100644 python/requirements.dev.txt create mode 100644 python/setup.py create mode 100644 task/python.yaml diff --git a/Taskfile.yaml b/Taskfile.yaml index ed3ccb7..d1c9961 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -21,6 +21,7 @@ includes: deploy: task/deploy.yaml proto: task/proto.yaml docs: task/docs.yaml + python: task/python.yaml sam: task/sam track: task/track @@ -39,6 +40,8 @@ tasks: - rm -rf app/frontend/src/proto - task: proto:schema - task: proto:service + # move to the Python SDK folder + - task: python:proto build: cmds: diff --git a/app/backend/segmentation.go b/app/backend/segmentation.go index 0c7ff7a..5adb8d1 100644 --- a/app/backend/segmentation.go +++ b/app/backend/segmentation.go @@ -7,14 +7,15 @@ import ( "fmt" "io" "net/http" - "nutsh/openapi/gen/nutshapi" - schemav1 "nutsh/proto/gen/schema/v1" - servicev1 "nutsh/proto/gen/service/v1" "os" "path" "path/filepath" "strings" + "nutsh/openapi/gen/nutshapi" + schemav1 "nutsh/proto/gen/go/schema/v1" + servicev1 "nutsh/proto/gen/go/service/v1" + "github.com/labstack/echo/v4" "github.com/pkg/errors" "go.uber.org/zap" diff --git a/app/backend/track.go b/app/backend/track.go index 754a373..8e47bee 100644 --- a/app/backend/track.go +++ b/app/backend/track.go @@ -12,8 +12,8 @@ import ( "strings" "nutsh/openapi/gen/nutshapi" - schemav1 "nutsh/proto/gen/schema/v1" - servicev1 "nutsh/proto/gen/service/v1" + schemav1 "nutsh/proto/gen/go/schema/v1" + servicev1 "nutsh/proto/gen/go/service/v1" "github.com/labstack/echo/v4" "github.com/pkg/errors" diff --git a/cmd/sam/action/hit.go b/cmd/sam/action/hit.go index c93e842..a746693 100644 --- a/cmd/sam/action/hit.go +++ b/cmd/sam/action/hit.go @@ -12,7 +12,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - servicev1 "nutsh/proto/gen/service/v1" + servicev1 "nutsh/proto/gen/go/service/v1" ) func Hit(ctx *cli.Context) error { diff --git a/module/sam/main.go b/module/sam/main.go index 4a55bb9..88441a2 100644 --- a/module/sam/main.go +++ b/module/sam/main.go @@ -14,7 +14,7 @@ import ( "nutsh/module/common" "nutsh/module/sam/server" - servicev1 "nutsh/proto/gen/service/v1" + servicev1 "nutsh/proto/gen/go/service/v1" ) //go:embed script/* diff --git a/module/sam/server/decoder.go b/module/sam/server/decoder.go index d394cd3..c5a6a0a 100644 --- a/module/sam/server/decoder.go +++ b/module/sam/server/decoder.go @@ -2,7 +2,7 @@ package server import ( "context" - servicev1 "nutsh/proto/gen/service/v1" + servicev1 "nutsh/proto/gen/go/service/v1" "os" "github.com/pkg/errors" diff --git a/module/sam/server/embed.go b/module/sam/server/embed.go index 46a525a..2c800c3 100644 --- a/module/sam/server/embed.go +++ b/module/sam/server/embed.go @@ -14,7 +14,7 @@ import ( "go.uber.org/zap" "nutsh/module/common" - servicev1 "nutsh/proto/gen/service/v1" + servicev1 "nutsh/proto/gen/go/service/v1" ) func (s *mServer) EmbedImage(ctx context.Context, req *servicev1.EmbedImageRequest) (*servicev1.EmbedImageResponse, error) { diff --git a/module/sam/server/introspect.go b/module/sam/server/introspect.go index 46d2ac5..2b67d06 100644 --- a/module/sam/server/introspect.go +++ b/module/sam/server/introspect.go @@ -2,7 +2,7 @@ package server import ( "context" - servicev1 "nutsh/proto/gen/service/v1" + servicev1 "nutsh/proto/gen/go/service/v1" "github.com/pkg/errors" "go.uber.org/zap" diff --git a/module/sam/server/new.go b/module/sam/server/new.go index cfa2355..129f406 100644 --- a/module/sam/server/new.go +++ b/module/sam/server/new.go @@ -2,7 +2,7 @@ package server import ( "fmt" - servicev1 "nutsh/proto/gen/service/v1" + servicev1 "nutsh/proto/gen/go/service/v1" "go.uber.org/zap" ) diff --git a/module/track/main.go b/module/track/main.go index 8e281bd..21022e4 100644 --- a/module/track/main.go +++ b/module/track/main.go @@ -12,7 +12,7 @@ import ( "google.golang.org/grpc" "nutsh/module/track/server" - servicev1 "nutsh/proto/gen/service/v1" + servicev1 "nutsh/proto/gen/go/service/v1" ) func main() { diff --git a/module/track/server/new.go b/module/track/server/new.go index 3e7d665..92cb5be 100644 --- a/module/track/server/new.go +++ b/module/track/server/new.go @@ -1,7 +1,7 @@ package server import ( - servicev1 "nutsh/proto/gen/service/v1" + servicev1 "nutsh/proto/gen/go/service/v1" ) func New(opts ...Option) (servicev1.TrackServiceServer, func()) { diff --git a/module/track/server/stream.go b/module/track/server/stream.go index 7ded205..b9aced8 100644 --- a/module/track/server/stream.go +++ b/module/track/server/stream.go @@ -9,7 +9,7 @@ import ( "os/exec" "path/filepath" - servicev1 "nutsh/proto/gen/service/v1" + servicev1 "nutsh/proto/gen/go/service/v1" "github.com/pkg/errors" "go.uber.org/zap" diff --git a/module/track/server/track.go b/module/track/server/track.go index 66f3637..f732f11 100644 --- a/module/track/server/track.go +++ b/module/track/server/track.go @@ -19,8 +19,8 @@ import ( "go.uber.org/zap" "nutsh/module/common" - schemav1 "nutsh/proto/gen/schema/v1" - servicev1 "nutsh/proto/gen/service/v1" + schemav1 "nutsh/proto/gen/go/schema/v1" + servicev1 "nutsh/proto/gen/go/service/v1" ) func (s *mServer) Track(ctx context.Context, req *servicev1.TrackRequest) (*servicev1.TrackResponse, error) { diff --git a/proto/buf.gen.service.yaml b/proto/buf.gen.service.yaml index f44f560..1a527e9 100644 --- a/proto/buf.gen.service.yaml +++ b/proto/buf.gen.service.yaml @@ -2,13 +2,25 @@ version: v1 managed: enabled: true go_package_prefix: - default: nutsh/proto/gen + default: nutsh/proto/gen/go plugins: - plugin: go - out: gen + out: gen/go opt: paths=source_relative - plugin: go-grpc - out: gen + out: gen/go opt: - paths=source_relative - require_unimplemented_servers=false + # Generate Python code. + # https://buf.build/grpc/python + - plugin: buf.build/grpc/python:v1.59.2 + out: gen/python + - plugin: buf.build/protocolbuffers/python + out: gen/python + # At the time writing, the official gRPC generation does not provide Python type hints. + # To rescue, we use a buf.build plugin to do so. + # https://github.com/grpc/grpc/issues/29041 + # https://buf.build/protocolbuffers/pyi + - plugin: buf.build/protocolbuffers/pyi:v25.0 + out: gen/python diff --git a/python/.gitignore b/python/.gitignore new file mode 100644 index 0000000..4757543 --- /dev/null +++ b/python/.gitignore @@ -0,0 +1,5 @@ +proto +*.egg-info +.mypy_cache +*.pyi +dist \ No newline at end of file diff --git a/python/README.md b/python/README.md new file mode 100644 index 0000000..18b1ee8 --- /dev/null +++ b/python/README.md @@ -0,0 +1,7 @@ +# The Python SDK for nutsh + +Installation: + +``` +pip install nutsh +``` diff --git a/python/develop.md b/python/develop.md new file mode 100644 index 0000000..29e3d73 --- /dev/null +++ b/python/develop.md @@ -0,0 +1,8 @@ +To setup the develop environment: + +``` +conda create --name nutsh-py-sdk python=3.10 +conda activate nutsh-py-sdk +pip install ./python +pip install -r python/requirements.dev.txt +``` diff --git a/python/nutsh/__init__.py b/python/nutsh/__init__.py new file mode 100644 index 0000000..9b0852b --- /dev/null +++ b/python/nutsh/__init__.py @@ -0,0 +1,4 @@ +import os +import sys +proto_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'proto')) +sys.path.append(proto_path) \ No newline at end of file diff --git a/python/nutsh/lib/__init__.py b/python/nutsh/lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/nutsh/lib/image.py b/python/nutsh/lib/image.py new file mode 100644 index 0000000..ad52216 --- /dev/null +++ b/python/nutsh/lib/image.py @@ -0,0 +1,66 @@ +import os +import requests +import hashlib +import base64 +import logging +import concurrent.futures +from urllib.parse import urlparse +from typing import List + +# assuming logging configuration is already set somewhere +logger = logging.getLogger(__name__) + +base64_jpeg_prefix = "data:image/jpeg;base64," +base64_png_prefix = "data:image/png;base64," + + +class ImagePreparer: + def __init__(self, uri: str): + self.uri = uri + + def prepare(self, save_dir: str): + name = hashlib.md5(self.uri.encode('utf-8')).hexdigest() + self._get_extension() + path = os.path.join(save_dir, name) + + if os.path.exists(path): + return path + if self.uri.startswith((base64_jpeg_prefix, base64_png_prefix)): + return self._save_base64(path) + return self._download(path) + + def _get_extension(self): + if self.uri.startswith(base64_jpeg_prefix): + return ".jpg" + if self.uri.startswith(base64_png_prefix): + return ".png" + return os.path.splitext(urlparse(self.uri).path)[1] + + def _save_base64(self, save_path: str): + logger.info("saving base64 image to path: %s", save_path) + prefix = base64_jpeg_prefix if self.uri.startswith(base64_jpeg_prefix) else base64_png_prefix + im_base64 = self.uri[len(prefix):] + image_data = base64.b64decode(im_base64) + with open(save_path, 'wb') as f: + f.write(image_data) + return save_path + + def _download(self, save_path: str): + logger.info("downloading image from url: %s to path: %s", self.uri, save_path) + response = requests.get(self.uri, stream=True) + if response.status_code != 200: + raise ValueError("failed to download {}, status code: {}".format(self.uri, response.status_code)) + with open(save_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + return save_path + + +def prepare_images(save_dir: str, im_uris: List[str]): + num_workers = os.cpu_count() + paths: List[str] = [""] * len(im_uris) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + future_to_index = {executor.submit(ImagePreparer(uri).prepare, save_dir): idx for idx, uri in enumerate(im_uris)} + for future in concurrent.futures.as_completed(future_to_index): + index = future_to_index[future] + paths[index] = future.result() + return paths diff --git a/python/nutsh/lib/logging.py b/python/nutsh/lib/logging.py new file mode 100644 index 0000000..1b8b7a8 --- /dev/null +++ b/python/nutsh/lib/logging.py @@ -0,0 +1,10 @@ +import os +import logging + +def init_logging(): + log_level = os.environ.get('LOG_LEVEL', 'INFO') + logging.basicConfig( + level=getattr(logging, log_level, None), + datefmt="%Y-%m-%d %H:%M:%S", + format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", + ) diff --git a/python/nutsh/track.py b/python/nutsh/track.py new file mode 100644 index 0000000..53e36d2 --- /dev/null +++ b/python/nutsh/track.py @@ -0,0 +1,37 @@ +import os +import logging +from typing import Callable + +import grpc +from concurrent import futures + +from .lib.logging import init_logging +from .track_grpc import TrackService, Tracker +from .proto.schema.v1.train_pb2 import Mask +from .proto.service.v1 import track_pb2_grpc + +class Service: + def __init__(self, workspace: str, on_reqeust: Callable[[str, Mask], Tracker]): + self.workspace = workspace + self.on_reqeust = on_reqeust + + def start(self, port: int) -> None: + init_logging() + + service = TrackService( + workspace=self.workspace, + on_reqeust=self.on_reqeust + ) + + logging.info(f"track stream service starts listening on :{port}") + num_workers = os.cpu_count() + server = grpc.server( # type: ignore + futures.ThreadPoolExecutor(max_workers=num_workers), + options=[ + ('grpc.max_receive_message_length', 1024 * 1024 * 1024) # allows 1K images each being of 1MB + ] + ) + track_pb2_grpc.add_TrackServiceServicer_to_server(service, server) # type: ignore + server.add_insecure_port(f'[::]:{port}') + server.start() + server.wait_for_termination() diff --git a/python/nutsh/track_grpc.py b/python/nutsh/track_grpc.py new file mode 100644 index 0000000..53701af --- /dev/null +++ b/python/nutsh/track_grpc.py @@ -0,0 +1,56 @@ +import os +import logging +from abc import ABC, abstractmethod +from typing import Callable, Iterator, List + +from grpc import ServicerContext, StatusCode + +from .proto.service.v1 import track_pb2_grpc +from .proto.service.v1.track_pb2 import FrameMask, TrackRequest, TrackResponse +from .proto.schema.v1.train_pb2 import Mask +from .lib.image import prepare_images + +class Tracker(ABC): + @abstractmethod + def predict(self, im_path: str) -> Mask: + """Predict on the given image and return a mask.""" + pass + + +class TrackService(track_pb2_grpc.TrackServiceServicer): + def __init__(self, workspace: str, on_reqeust: Callable[[str, Mask], Tracker]): + self.workspace = workspace + self.on_reqeust = on_reqeust + + def Track(self, request: TrackRequest, context: ServicerContext) -> TrackResponse: + logging.info("received a Track request") + im_paths = self._prepare_images(request=request, context=context) + + # predict + tracker = self.on_reqeust(im_paths[0], request.first_image_mask) + masks = [tracker.predict(im_path) for im_path in im_paths[1:]] + + return TrackResponse(subsequent_image_masks=masks) + + def TrackStream(self, request: TrackRequest, context: ServicerContext) -> Iterator[FrameMask]: + logging.info("received a TrackStream request") + im_paths = self._prepare_images(request=request, context=context) + + # predict + tracker = self.on_reqeust(im_paths[0], request.first_image_mask) + for i, im_path in enumerate(im_paths[1:]): + mask = tracker.predict(im_path) + yield FrameMask(frame_index=i+1, mask=mask) + + def _prepare_images(self, request: TrackRequest, context: ServicerContext) -> List[str]: + im_uris = [request.first_image_uri, *request.subsequent_image_uris] + im_dir = os.path.join(self.workspace, "images") + if not os.path.exists(im_dir): + os.makedirs(im_dir) + try: + im_paths = prepare_images(im_dir, im_uris) + except Exception as e: + logging.error(f"failed to prepare images: {e}") + context.abort_with_status(StatusCode.INTERNAL.value) + + return im_paths \ No newline at end of file diff --git a/python/requirements.dev.txt b/python/requirements.dev.txt new file mode 100644 index 0000000..32f135b --- /dev/null +++ b/python/requirements.dev.txt @@ -0,0 +1,4 @@ +setuptools +mypy +build +twine \ No newline at end of file diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 0000000..a8d3554 --- /dev/null +++ b/python/setup.py @@ -0,0 +1,21 @@ +from setuptools import setup, find_packages + +setup( + name='nutsh', + version='0.0.1-4', + url='https://nutsh.ai/', + author='Xu Han', + author_email='hxhxhx88@gmail.com', + description='The Python SDK of nutsh.ai', + long_description=open('README.md').read(), + long_description_content_type='text/markdown', + packages=find_packages(), + package_data={ + '': ['*.pyi', 'requirements.txt'], # include all .pyi files + }, + install_requires=[ + "grpcio-tools==1.59.2", + "grpc-stubs==1.53.0.3", + "requests==2.31.0" + ], +) \ No newline at end of file diff --git a/task/python.yaml b/task/python.yaml new file mode 100644 index 0000000..eceefc5 --- /dev/null +++ b/task/python.yaml @@ -0,0 +1,23 @@ +version: "3" + +tasks: + stub: + cmds: + - stubgen nutsh/track.py -o . + dir: python + + proto: + cmds: + - rm -rf python/nutsh/proto + - cp -r proto/gen/python python/nutsh/proto + - find python/nutsh/proto -type d -exec touch {}/__init__.py \; + + dist: + cmds: + - task: proto + - task: stub + - rm -rf dist + - python3 -m build --sdist + - twine check dist/* + - twine upload dist/* + dir: python