-
Notifications
You must be signed in to change notification settings - Fork 0
/
api.py
69 lines (59 loc) · 2.17 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from fastapi import FastAPI, Query
from demo import init_detector, inference_detector
import uvicorn
from fastapi import File, UploadFile
import os.path as osp
import os
import numpy as np
import mmcv
import cv2
import uuid
from fastapi.staticfiles import StaticFiles
from enum import Enum
router = FastAPI()
# Not read anywhere in the visible code; kept so existing imports of
# `api.model` keep working.
model = None
# StaticFiles raises at construction time when its directory does not exist,
# and /predict writes its annotated images into this same directory, so make
# sure it exists before mounting.
os.makedirs("demo_images", exist_ok=True)
# NOTE(review): the route name is "/static" (with a slash), so url_for would
# need that exact string; "static" is the usual convention.
router.mount("/static", StaticFiles(directory="demo_images"), name="/static")
def load_model(config, checkpoint, device="cuda:1"):
    """Build a detector from a config file and load its checkpoint.

    Args:
        config: Path to the model config file.
        checkpoint: Path to the checkpoint (``.pth``) file.
        device: Device string forwarded to ``init_detector``. Defaults to
            ``"cuda:1"``, the previously hard-coded value; override it on
            machines with a different GPU layout (presumably any device
            string ``init_detector`` accepts — confirm against ``demo``).

    Returns:
        The object returned by ``init_detector``. /predict passes this object
        to ``inference_detector``.
    """
    return init_detector(config=config, checkpoint=checkpoint, device=device)
# Load the detector once at import time so every request reuses it.
# The config/checkpoint default to the original hard-coded locations; set
# KFIOU_CONFIG / KFIOU_CHECKPOINT to run on a host with a different layout.
print("Loading Kfiou: ")
kfiou = load_model(
    os.environ.get(
        "KFIOU_CONFIG",
        "/mlcv1/WorkingSpace/Personal/hienht/Dense/mmrotate/train_fliou/r3det_kfiou_ln_swin_tiny_adamw_fpn_1x_dota_ms_rr_oc_2.py",
    ),
    os.environ.get(
        "KFIOU_CHECKPOINT",
        "/mlcv1/WorkingSpace/Personal/hienht/Dense/mmrotate/train_fliou/epoch_50.pth",
    ),
)
@router.get("/")
async def getRoot():
    """Root endpoint: respond with a fixed JSON greeting."""
    greeting = {"message": "Hello World"}
    return greeting
def _rotated_box_to_polygon(bbox):
    """Return the four integer corner points of a rotated box.

    ``bbox`` is read as ``(xc, yc, w, h, angle, ...)``. ``angle`` is passed
    directly to ``np.cos``/``np.sin``, so it is taken to be in radians. Any
    trailing entries, such as the score, are ignored. Corners are truncated
    toward zero and cast to ``int32``, the point type ``cv2.polylines``
    expects.
    """
    xc, yc, w, h, ag = bbox[:5]
    # Half-extent vectors along the box's width and height axes.
    wx, wy = w / 2 * np.cos(ag), w / 2 * np.sin(ag)
    hx, hy = -h / 2 * np.sin(ag), h / 2 * np.cos(ag)
    corners = [
        (xc - wx - hx, yc - wy - hy),
        (xc + wx - hx, yc + wy - hy),
        (xc + wx + hx, yc + wy + hy),
        (xc - wx + hx, yc - wy + hy),
    ]
    # The previous np.int0 alias was removed in NumPy 2.0. astype truncates
    # the same way the alias did.
    return np.array(corners).astype(np.int32)


@router.post("/predict")
async def predict(
    image: UploadFile = File(...),
    score_thr: float = Query(0.3, ge=0.0, le=1.0),
):
    """Run the detector on an uploaded image and save an annotated copy.

    Processing steps:

    1. Write the upload to ``temp/<uuid>.png``.
    2. Run the module-level ``kfiou`` model on it.
    3. Read the temporary file back and then delete it. Deletion also happens
       when inference fails.
    4. Draw every detection in ``result[0]`` whose last column (score) is at
       least ``score_thr`` as a red polygon.
    5. Write the annotated image to ``demo_images/``.

    Args:
        image: The uploaded image file.
        score_thr: Optional query parameter giving the minimum score a
            detection needs in order to be drawn. Defaults to 0.3, the
            previously hard-coded threshold.

    Returns:
        ``{"image_result": "demo_images/<uuid>.png"}``. The value is a
        filesystem path relative to the working directory; the same
        directory is mounted at ``/static``.
    """
    temp_dir = "temp"
    result_dir = "demo_images/"
    # Both directories are written below; create them if missing instead of
    # failing with FileNotFoundError on the first request.
    os.makedirs(temp_dir, exist_ok=True)
    os.makedirs(result_dir, exist_ok=True)

    # Discard the client-supplied filename in favour of a random one. This
    # prevents collisions between requests and path tricks in the name.
    image.filename = f"{uuid.uuid4()}.png"
    temp_path = osp.join(temp_dir, image.filename)
    with open(temp_path, "wb") as f:
        f.write(await image.read())
    try:
        # NOTE(review): inference runs synchronously inside an async handler,
        # so it blocks the event loop while it runs.
        result = inference_detector(kfiou, temp_path)
        img = mmcv.imread(temp_path)
    finally:
        os.remove(temp_path)

    # NOTE(review): only result[0] is drawn. This is presumably class 0 of a
    # per-class result list; confirm whether other classes should be drawn.
    for bbox in result[0]:
        if bbox[-1] < score_thr:
            continue
        poly = _rotated_box_to_polygon(bbox)
        img = cv2.polylines(img, [poly], True, (0, 0, 255), 2)

    out_path = osp.join(result_dir, image.filename)
    mmcv.imwrite(img, out_path)
    return {"image_result": out_path}
if __name__ == "__main__":
    # Development entry point: serve on all interfaces at port 35000 with
    # auto-reload enabled.
    # NOTE(review): `python api.py` executes this file as __main__, which
    # loads the model at module level. uvicorn then imports "api" again from
    # the import string, so the model is loaded a second time. Check GPU
    # memory, or launch with `uvicorn api:router` instead.
    uvicorn.run(app="api:router", host="0.0.0.0", port=35000, reload=True)