-
Notifications
You must be signed in to change notification settings - Fork 0
/
classification.py
75 lines (71 loc) · 2.88 KB
/
classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import tflite_runtime.interpreter as tflite
import numpy as np
import cv2
def inv_contour(image, mask, x, y, w, h):
    """Crop `image` to the largest hole inside the masked region.

    The mask window at (x, y, w, h) is inverted, so the contours found are
    the non-mask areas inside the detected region; the largest such contour
    is assumed to be the subject to classify.

    Args:
        image: source image array (H, W, 3) to crop from.
        mask:  single-channel 0/255 mask aligned with `image`.
        x, y, w, h: bounding box of the region of interest inside `mask`.

    Returns:
        Sub-image of `image` covering the largest inverted-mask contour.
    """
    # Invert the ROI (255 <-> 0) and scale to 0/1: findContours requires a
    # single-channel 8-bit source image.
    roi = (mask[y:y + h, x:x + w] ^ 0xFF) // 255
    roi = roi.astype('uint8')
    contours, _ = cv2.findContours(roi, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    # Largest contour by area. The original index loop computed
    # cv2.contourArea twice per contour; max() with a key does it once.
    largest = max(contours, key=cv2.contourArea)
    X_, Y_, W_, H_ = cv2.boundingRect(largest)
    # Translate the ROI-local box back to full-image coordinates.
    return image[y + Y_:y + Y_ + H_, x + X_:x + X_ + W_, :]
def load_interperter(model_path):
    """Build a TFLite interpreter for `model_path`.

    Returns:
        dict with keys 'model' (the allocated Interpreter) and 'input' /
        'output' (zero-argument callables yielding the first input and
        first output tensor buffers).
    """
    interpreter = tflite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    in_index = interpreter.get_input_details()[0]['index']
    out_index = interpreter.get_output_details()[0]['index']
    return {
        'model': interpreter,
        'input': interpreter.tensor(in_index),
        'output': interpreter.tensor(out_index),
    }
def load_labels(path):
    """Read a label file and map line index to label text.

    Args:
        path: path to a text file containing one label per line.

    Returns:
        dict mapping the 0-based line number to the stripped label string.
    """
    with open(path, 'r') as f:
        # Iterate the file object lazily instead of materializing the whole
        # file with readlines().
        return {i: line.strip() for i, line in enumerate(f)}
# --- Model and label setup --------------------------------------------------
dic = load_interperter("/home/pi/raspi4withTF/tflitemodel.tflite")
#labels=load_labels("/home/pi/Downloads/labels_mobilenet_quant_v1_224.txt")
labels = {1: 'dog', 0: 'cat'}

# -1 asks OpenCV for the first available capture device.
cap = cv2.VideoCapture(-1)
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        # Camera read failed. The original resized `frame` (None here)
        # before ever checking `ret`, which crashed inside cv2.resize.
        break
    frame = cv2.resize(frame, (205 * 2, 154 * 2))
    frame = cv2.flip(frame, 0)
    # Convert to RGB for the classifier input.
    RGB_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Isolate the red hue band in HSV space.
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    red_mask_2 = cv2.inRange(hsv, (170, 100, 80), (180, 255, 255))

    # Binarize to 0/1: findContours needs a single-channel 8-bit source.
    bin_mask = (red_mask_2 / 255).astype('uint8')

    RGB_frame_copy = RGB_frame.copy()
    contours, hierarchy = cv2.findContours(bin_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        # No red region in this frame; the original raised on max([]).
        continue
    cv2.drawContours(RGB_frame_copy, contours, -1, (0, 255, 0), 3)

    # The largest contour is taken to be the region of interest.
    c0 = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(c0)

    # Crop the subject: the largest hole inside the red mask.
    target_image = inv_contour(RGB_frame, red_mask_2, x, y, w, h)

    # Run inference on the cropped patch resized to the model's input size.
    test_data = cv2.resize(target_image, (150, 150))
    dic['input']()[0][:, :] = test_data
    dic['model'].invoke()
    ans = labels[np.argmax(dic['output']()[0])]

    # Draw the detection box. The original passed 7 where cv2.rectangle
    # expects the color argument; make the color explicit and use 7 as the
    # line thickness (presumably what was intended — confirm on device).
    result_image = cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 7)
    BGR_frame = cv2.cvtColor(RGB_frame_copy, cv2.COLOR_RGB2BGR)
    cv2.putText(BGR_frame, ans, (0, 100), cv2.FONT_ITALIC, 1, (0, 0, 0))

    cv2.imshow('contours', BGR_frame)
    cv2.imshow('cutimage', result_image)
    cv2.imshow('result', target_image)

    # Exit on ESC.
    if cv2.waitKey(1) & 0xFF == 27:
        break

cap.release()
cv2.destroyAllWindows()