app.py
# -*- coding: utf-8 -*-
# Copyright (c) Louis Brulé Naudet. All Rights Reserved.
# This software may be used and distributed according to the terms of License Agreement.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gradio as gr

try:
    import detectron2
except ImportError:
    # Fall back to installing detectron2 from the forked repository if it is missing.
    import os
    os.system('pip install git+https://github.com/louisbrulenaudet/detectron2')

import cv2
import torch

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
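# All model state (weights path, class names, detectron2 config, metadata) is kept
# in a single dict so the Gradio callback can reach it through one global.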
model = {
    "name": "Sealife detection",
    "model_path": "./model_final.pth",
    "classes": ["creatures", "fish", "jellyfish", "penguin", "puffin", "shark", "starfish", "stingray"],
    "cfg": None,
    "metadata": None,
}

# Build the detectron2 config from the local Faster R-CNN R50-FPN base config
# and point it at the fine-tuned weights.
model["cfg"] = get_cfg()
model["cfg"].merge_from_file("./configs/faster_rcnn_R_50_FPN_3x.yaml")
model["cfg"].MODEL.ROI_HEADS.NUM_CLASSES = len(model["classes"])
model["cfg"].MODEL.WEIGHTS = model["model_path"]

# Register the class names so the Visualizer can label predictions.
model["metadata"] = MetadataCatalog.get(model["name"])
model["metadata"].thing_classes = model["classes"]

# Run on CPU when no GPU is available.
if not torch.cuda.is_available():
    model["cfg"].MODEL.DEVICE = "cpu"
def inference(image, threshold):
    """
    Run object detection on an input image and draw the predictions.

    Parameters
    ----------
    image : numpy.ndarray
        The input image in RGB format, as provided by Gradio.
    threshold : float
        The minimum confidence score for a prediction to be kept.

    Returns
    -------
    numpy.ndarray
        An RGB image with the instance predictions drawn by the model.
    """
    global model

    # The model expects BGR input, so flip the channel order.
    im = image[:, :, ::-1]

    # Apply the requested score threshold before building the predictor.
    model["cfg"].MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold

    predictor = DefaultPredictor(model["cfg"])
    outputs = predictor(im)

    # Draw the predictions and flip back to RGB for display.
    v = Visualizer(im, model["metadata"], scale=1.2)
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

    return out.get_image()[:, :, ::-1]
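# A minimal usage sketch outside the Gradio UI (assuming a local test image at
# "./sample.jpg", which is not part of this repository):
#
#     img_bgr = cv2.imread("./sample.jpg")           # OpenCV loads images as BGR
#     img_rgb = img_bgr[:, :, ::-1]                  # convert to RGB, as Gradio would pass it
#     annotated = inference(img_rgb, threshold=0.7)  # RGB numpy array with drawn boxes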
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="numpy", label="Input Image")
            threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Threshold value")
            inference_button = gr.Button("Submit")

        image_output = gr.Image(type="pil", label="Output")

    inference_button.click(fn=inference, inputs=[image_input, threshold], outputs=image_output)

demo.launch()