-
Notifications
You must be signed in to change notification settings - Fork 0
/
prediction.py
22 lines (17 loc) · 817 Bytes
/
prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import os
import pixellib
import matplotlib.pyplot as plt
from pixellib.instance import instance_segmentation
# Weights file handed to instance_segmentation().load_model() in prediction().
# Despite the "URL" name this is a bare local filename, so it is presumably
# resolved relative to the current working directory — confirm deployment layout.
DATA_URL: str = "mask_rcnn_coco.h5"
def prediction(img_file, model_path=None, output_dir="output_images"):
    """Segment an image with PixelLib and return two rendered versions of it.

    A fresh ``instance_segmentation`` object is created and a Mask R-CNN model
    (trained on COCO, per the original author's comment) is loaded from
    ``model_path``. The image is then segmented twice:

    * once with the default rendering, saved as ``<output_dir>/out.jpg``;
    * once with ``show_bboxes=True``, saved as ``<output_dir>/out_box.jpg``.

    Each saved file is read back with ``plt.imread`` and returned.

    Args:
        img_file: Input image, passed unchanged to ``segmentImage``
            (presumably a file path — confirm against callers).
        model_path: Weights file passed to ``load_model``. When ``None``
            (the default) the module-level ``DATA_URL`` is used, looked up
            at call time.
        output_dir: Directory the two rendered images are written to. It is
            created if it does not exist. Defaults to ``"output_images"``,
            matching the previously hard-coded location.

    Returns:
        A tuple ``(out, out_box)`` holding the arrays returned by
        ``plt.imread`` for the plain rendering and the bounding-box
        rendering, respectively.
    """
    if model_path is None:
        # Resolved here rather than as a default value so that reassigning
        # DATA_URL at runtime still takes effect, as in the original code.
        model_path = DATA_URL

    # PixelLib writes the output files itself and the code below reads them
    # back. Without this, a missing directory makes the save fail (possibly
    # silently, if it goes through cv2.imwrite — unverified) and the
    # subsequent read fail.
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, "out.jpg")
    out_box_path = os.path.join(output_dir, "out_box.jpg")

    # instantiating the instance segmentation class
    segment_image = instance_segmentation()
    # loading the model mask rcnn trained on coco dataset
    segment_image.load_model(model_path)

    # performing the segmentation on the input image
    segment_image.segmentImage(img_file, output_image_name=out_path)
    # No second argument: plt.imread's 2nd positional parameter is `format`,
    # not a read-mode flag like cv2.imread's (the old `0` did nothing useful).
    out = plt.imread(out_path)

    # performing the segmentation on the input image with bounding boxes
    segment_image.segmentImage(
        img_file, output_image_name=out_box_path, show_bboxes=True
    )
    out_box = plt.imread(out_box_path)

    return out, out_box