Skip to content

Commit

Permalink
added python sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
hxhxhx88 committed Nov 6, 2023
1 parent 841eab3 commit cc4fcb8
Show file tree
Hide file tree
Showing 26 changed files with 276 additions and 19 deletions.
3 changes: 3 additions & 0 deletions Taskfile.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ includes:
deploy: task/deploy.yaml
proto: task/proto.yaml
docs: task/docs.yaml
python: task/python.yaml
sam: task/sam
track: task/track

Expand All @@ -39,6 +40,8 @@ tasks:
- rm -rf app/frontend/src/proto
- task: proto:schema
- task: proto:service
# move to the Python SDK folder
- task: python:proto

build:
cmds:
Expand Down
7 changes: 4 additions & 3 deletions app/backend/segmentation.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ import (
"fmt"
"io"
"net/http"
"nutsh/openapi/gen/nutshapi"
schemav1 "nutsh/proto/gen/schema/v1"
servicev1 "nutsh/proto/gen/service/v1"
"os"
"path"
"path/filepath"
"strings"

"nutsh/openapi/gen/nutshapi"
schemav1 "nutsh/proto/gen/go/schema/v1"
servicev1 "nutsh/proto/gen/go/service/v1"

"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"go.uber.org/zap"
Expand Down
4 changes: 2 additions & 2 deletions app/backend/track.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
"strings"

"nutsh/openapi/gen/nutshapi"
schemav1 "nutsh/proto/gen/schema/v1"
servicev1 "nutsh/proto/gen/service/v1"
schemav1 "nutsh/proto/gen/go/schema/v1"
servicev1 "nutsh/proto/gen/go/service/v1"

"github.com/labstack/echo/v4"
"github.com/pkg/errors"
Expand Down
2 changes: 1 addition & 1 deletion cmd/sam/action/hit.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

servicev1 "nutsh/proto/gen/service/v1"
servicev1 "nutsh/proto/gen/go/service/v1"
)

func Hit(ctx *cli.Context) error {
Expand Down
2 changes: 1 addition & 1 deletion module/sam/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

"nutsh/module/common"
"nutsh/module/sam/server"
servicev1 "nutsh/proto/gen/service/v1"
servicev1 "nutsh/proto/gen/go/service/v1"
)

//go:embed script/*
Expand Down
2 changes: 1 addition & 1 deletion module/sam/server/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package server

import (
"context"
servicev1 "nutsh/proto/gen/service/v1"
servicev1 "nutsh/proto/gen/go/service/v1"
"os"

"github.com/pkg/errors"
Expand Down
2 changes: 1 addition & 1 deletion module/sam/server/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"go.uber.org/zap"

"nutsh/module/common"
servicev1 "nutsh/proto/gen/service/v1"
servicev1 "nutsh/proto/gen/go/service/v1"
)

func (s *mServer) EmbedImage(ctx context.Context, req *servicev1.EmbedImageRequest) (*servicev1.EmbedImageResponse, error) {
Expand Down
2 changes: 1 addition & 1 deletion module/sam/server/introspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package server

import (
"context"
servicev1 "nutsh/proto/gen/service/v1"
servicev1 "nutsh/proto/gen/go/service/v1"

"github.com/pkg/errors"
"go.uber.org/zap"
Expand Down
2 changes: 1 addition & 1 deletion module/sam/server/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package server

import (
"fmt"
servicev1 "nutsh/proto/gen/service/v1"
servicev1 "nutsh/proto/gen/go/service/v1"

"go.uber.org/zap"
)
Expand Down
2 changes: 1 addition & 1 deletion module/track/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"google.golang.org/grpc"

"nutsh/module/track/server"
servicev1 "nutsh/proto/gen/service/v1"
servicev1 "nutsh/proto/gen/go/service/v1"
)

func main() {
Expand Down
2 changes: 1 addition & 1 deletion module/track/server/new.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package server

import (
servicev1 "nutsh/proto/gen/service/v1"
servicev1 "nutsh/proto/gen/go/service/v1"
)

func New(opts ...Option) (servicev1.TrackServiceServer, func()) {
Expand Down
2 changes: 1 addition & 1 deletion module/track/server/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"os/exec"
"path/filepath"

servicev1 "nutsh/proto/gen/service/v1"
servicev1 "nutsh/proto/gen/go/service/v1"

"github.com/pkg/errors"
"go.uber.org/zap"
Expand Down
4 changes: 2 additions & 2 deletions module/track/server/track.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (
"go.uber.org/zap"

"nutsh/module/common"
schemav1 "nutsh/proto/gen/schema/v1"
servicev1 "nutsh/proto/gen/service/v1"
schemav1 "nutsh/proto/gen/go/schema/v1"
servicev1 "nutsh/proto/gen/go/service/v1"
)

func (s *mServer) Track(ctx context.Context, req *servicev1.TrackRequest) (*servicev1.TrackResponse, error) {
Expand Down
18 changes: 15 additions & 3 deletions proto/buf.gen.service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,25 @@ version: v1
managed:
enabled: true
go_package_prefix:
default: nutsh/proto/gen
default: nutsh/proto/gen/go
plugins:
- plugin: go
out: gen
out: gen/go
opt: paths=source_relative
- plugin: go-grpc
out: gen
out: gen/go
opt:
- paths=source_relative
- require_unimplemented_servers=false
# Generate Python code.
# https://buf.build/grpc/python
- plugin: buf.build/grpc/python:v1.59.2
out: gen/python
- plugin: buf.build/protocolbuffers/python
out: gen/python
# At the time writing, the official gRPC generation does not provide Python type hints.
# To rescue, we use a buf.build plugin to do so.
# https://github.com/grpc/grpc/issues/29041
# https://buf.build/protocolbuffers/pyi
- plugin: buf.build/protocolbuffers/pyi:v25.0
out: gen/python
5 changes: 5 additions & 0 deletions python/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
proto
*.egg-info
.mypy_cache
*.pyi
dist
7 changes: 7 additions & 0 deletions python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# The Python SDK for nutsh

Installation:

```
pip install nutsh
```
8 changes: 8 additions & 0 deletions python/develop.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
To setup the develop environment:

```
conda create --name nutsh-py-sdk python=3.10
conda activate nutsh-py-sdk
pip install ./python
pip install -r python/requirements.dev.txt
```
4 changes: 4 additions & 0 deletions python/nutsh/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import os
import sys
proto_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'proto'))
sys.path.append(proto_path)
Empty file added python/nutsh/lib/__init__.py
Empty file.
66 changes: 66 additions & 0 deletions python/nutsh/lib/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import requests
import hashlib
import base64
import logging
import concurrent.futures
from urllib.parse import urlparse
from typing import List

# assuming logging configuration is already set somewhere
logger = logging.getLogger(__name__)

base64_jpeg_prefix = "data:image/jpeg;base64,"
base64_png_prefix = "data:image/png;base64,"


class ImagePreparer:
def __init__(self, uri: str):
self.uri = uri

def prepare(self, save_dir: str):
name = hashlib.md5(self.uri.encode('utf-8')).hexdigest() + self._get_extension()
path = os.path.join(save_dir, name)

if os.path.exists(path):
return path
if self.uri.startswith((base64_jpeg_prefix, base64_png_prefix)):
return self._save_base64(path)
return self._download(path)

def _get_extension(self):
if self.uri.startswith(base64_jpeg_prefix):
return ".jpg"
if self.uri.startswith(base64_png_prefix):
return ".png"
return os.path.splitext(urlparse(self.uri).path)[1]

def _save_base64(self, save_path: str):
logger.info("saving base64 image to path: %s", save_path)
prefix = base64_jpeg_prefix if self.uri.startswith(base64_jpeg_prefix) else base64_png_prefix
im_base64 = self.uri[len(prefix):]
image_data = base64.b64decode(im_base64)
with open(save_path, 'wb') as f:
f.write(image_data)
return save_path

def _download(self, save_path: str):
logger.info("downloading image from url: %s to path: %s", self.uri, save_path)
response = requests.get(self.uri, stream=True)
if response.status_code != 200:
raise ValueError("failed to download {}, status code: {}".format(self.uri, response.status_code))
with open(save_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
return save_path


def prepare_images(save_dir: str, im_uris: List[str]):
num_workers = os.cpu_count()
paths: List[str] = [""] * len(im_uris)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
future_to_index = {executor.submit(ImagePreparer(uri).prepare, save_dir): idx for idx, uri in enumerate(im_uris)}
for future in concurrent.futures.as_completed(future_to_index):
index = future_to_index[future]
paths[index] = future.result()
return paths
10 changes: 10 additions & 0 deletions python/nutsh/lib/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os
import logging

def init_logging():
log_level = os.environ.get('LOG_LEVEL', 'INFO')
logging.basicConfig(
level=getattr(logging, log_level, None),
datefmt="%Y-%m-%d %H:%M:%S",
format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s",
)
37 changes: 37 additions & 0 deletions python/nutsh/track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import logging
from typing import Callable

import grpc
from concurrent import futures

from .lib.logging import init_logging
from .track_grpc import TrackService, Tracker
from .proto.schema.v1.train_pb2 import Mask
from .proto.service.v1 import track_pb2_grpc

class Service:
def __init__(self, workspace: str, on_reqeust: Callable[[str, Mask], Tracker]):
self.workspace = workspace
self.on_reqeust = on_reqeust

def start(self, port: int) -> None:
init_logging()

service = TrackService(
workspace=self.workspace,
on_reqeust=self.on_reqeust
)

logging.info(f"track stream service starts listening on :{port}")
num_workers = os.cpu_count()
server = grpc.server( # type: ignore
futures.ThreadPoolExecutor(max_workers=num_workers),
options=[
('grpc.max_receive_message_length', 1024 * 1024 * 1024) # allows 1K images each being of 1MB
]
)
track_pb2_grpc.add_TrackServiceServicer_to_server(service, server) # type: ignore
server.add_insecure_port(f'[::]:{port}')
server.start()
server.wait_for_termination()
56 changes: 56 additions & 0 deletions python/nutsh/track_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import logging
from abc import ABC, abstractmethod
from typing import Callable, Iterator, List

from grpc import ServicerContext, StatusCode

from .proto.service.v1 import track_pb2_grpc
from .proto.service.v1.track_pb2 import FrameMask, TrackRequest, TrackResponse
from .proto.schema.v1.train_pb2 import Mask
from .lib.image import prepare_images

class Tracker(ABC):
@abstractmethod
def predict(self, im_path: str) -> Mask:
"""Predict on the given image and return a mask."""
pass


class TrackService(track_pb2_grpc.TrackServiceServicer):
def __init__(self, workspace: str, on_reqeust: Callable[[str, Mask], Tracker]):
self.workspace = workspace
self.on_reqeust = on_reqeust

def Track(self, request: TrackRequest, context: ServicerContext) -> TrackResponse:
logging.info("received a Track request")
im_paths = self._prepare_images(request=request, context=context)

# predict
tracker = self.on_reqeust(im_paths[0], request.first_image_mask)
masks = [tracker.predict(im_path) for im_path in im_paths[1:]]

return TrackResponse(subsequent_image_masks=masks)

def TrackStream(self, request: TrackRequest, context: ServicerContext) -> Iterator[FrameMask]:
logging.info("received a TrackStream request")
im_paths = self._prepare_images(request=request, context=context)

# predict
tracker = self.on_reqeust(im_paths[0], request.first_image_mask)
for i, im_path in enumerate(im_paths[1:]):
mask = tracker.predict(im_path)
yield FrameMask(frame_index=i+1, mask=mask)

def _prepare_images(self, request: TrackRequest, context: ServicerContext) -> List[str]:
im_uris = [request.first_image_uri, *request.subsequent_image_uris]
im_dir = os.path.join(self.workspace, "images")
if not os.path.exists(im_dir):
os.makedirs(im_dir)
try:
im_paths = prepare_images(im_dir, im_uris)
except Exception as e:
logging.error(f"failed to prepare images: {e}")
context.abort_with_status(StatusCode.INTERNAL.value)

return im_paths
4 changes: 4 additions & 0 deletions python/requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
setuptools
mypy
build
twine
Loading

0 comments on commit cc4fcb8

Please sign in to comment.