Skip to content

Commit

Permalink
adding example
Browse files Browse the repository at this point in the history
  • Loading branch information
hxhxhx88 committed Nov 12, 2023
1 parent cc4fcb8 commit 7d6ccec
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 0 deletions.
1 change: 1 addition & 0 deletions example/track/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
vendor
51 changes: 51 additions & 0 deletions example/track/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
An example implementation of a tracking service using nutsh Python SDK.

# Runtime Environment

Take the following steps as an example to setup the runtim environment.

1. Prepare a conda virtual environment.

```
conda create --name nutsh-track python=3.10 -y && \
conda activate nutsh-track
```

2. Install `torch` and `torchvision` following [the official guide](https://pytorch.org/get-started/locally/).

3. Install other dependencies.

```
pip install -r requirements.txt
```

4. Download the [Segment-and-Track-Anything](https://github.com/z-x-yang/Segment-and-Track-Anything.git) model and extract the useful part for us.

```
git clone https://github.com/z-x-yang/Segment-and-Track-Anything.git vendor/Segment-and-Track-Anything && \
cd vendor/Segment-and-Track-Anything && git checkout 77354b5 && cd - && \
mv vendor/Segment-and-Track-Anything/aot vendor/aot && \
rm -rf vendor/Segment-and-Track-Anything
```

5. Install [Pytorch-Correlation-extension](https://github.com/ClementPinard/Pytorch-Correlation-extension.git) from its source code.

```
git clone https://github.com/ClementPinard/Pytorch-Correlation-extension.git vendor/Pytorch-Correlation-extension && \
cd vendor/Pytorch-Correlation-extension && \
git checkout 14a159e && \
python setup.py install && cd -
```

6. Download the model checkpoints.

```
mkdir -p local/ckpt && \
gdown --id '1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ' --output local/ckpt/R50_DeAOTL_PRE_YTB_DAV.pth
```

7. Install nutsh Python SDK.

```
pip install nutsh
```
4 changes: 4 additions & 0 deletions example/track/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
opencv-python==4.8.1.78
pycocotools==2.0.7
scikit-image==0.22.0
gdown==4.7.1
72 changes: 72 additions & 0 deletions example/track/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import argparse
import cv2
import gc
import json
import logging
import numpy as np
import os
import sys
import torch
from pycocotools import mask as coco_mask
from tracker_nutsh import NutshTracker

from nutsh.track import Service, Tracker
from nutsh.proto.schema.v1.train_pb2 import Mask

class TrackerImpl(Tracker):
def __init__(self, tracker):
self.tracker = tracker

def predict(self, im_path: str) -> Mask:
torch.cuda.empty_cache()
gc.collect()

frame = cv2.imread(im_path)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pred_mask = self.tracker.track(frame, update_memory=True)

rle = coco_mask.encode(np.asfortranarray(pred_mask))
return {
"coco_encoded_rle": rle['counts'].decode('utf-8'),
"size": {
"width": rle["size"][1],
"height": rle["size"][0],
}
}

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=6)
parser.add_argument('--model-path', type=str, default="local/ckpt/R50_DeAOTL_PRE_YTB_DAV.pth")
parser.add_argument('--workspace', type=str, default="local/workspace")
parser.add_argument('--port', type=int, default=12348)
args = parser.parse_args()

def new_tracker(first_image_path: str, first_image_mask: Mask) -> Tracker:
aot_tracker = NutshTracker({
'gpu_id': args.gpu,
'model_path': args.model_path,
'phase': 'PRE_YTB_DAV',
'model': 'r50_deaotl',
'long_term_mem_gap': 9999,
'max_len_long_term': 9999,
})
aot_tracker.restart_tracker()

# read first frame mask
m = first_image_mask
rle_str, mw, mh = m.coco_encoded_rle, m.size.width, m.size.height
mask = coco_mask.decode({"size": [mh, mw], "counts": rle_str})

# read first frame
first_frame = cv2.imread(first_image_path)
first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
aot_tracker.add_reference(frame=first_frame, mask=mask)

return TrackerImpl(tracker=aot_tracker)

ser = Service(workspace=args.workspace, on_reqeust=new_tracker)
ser.start(args.port)

if __name__ == "__main__":
main()

0 comments on commit 7d6ccec

Please sign in to comment.