Skip to content

Commit

Permalink
tested test_inference in workspace
Browse files Browse the repository at this point in the history
  • Loading branch information
AmeyaWagh committed May 29, 2018
1 parent 0158d46 commit 2816315
Showing 2 changed files with 38 additions and 23 deletions.
19 changes: 11 additions & 8 deletions inference.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
import matplotlib
import matplotlib.pyplot as plt
import skimage.io
from imgaug import augmenters as iaa
# from imgaug import augmenters as iaa

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

@@ -118,19 +118,22 @@ def segment_images(original_image):
color_id=1
# print('id:',_id)
mask_1 = f_mask[:,:,ch]
print(mask_1)
mask1 = np.dstack([mask_1*colors[color_id][0],
mask_1*colors[color_id][1],
mask_1*colors[color_id][2]])
final_img = cv2.addWeighted(final_img, 1, mask1.astype(np.uint8), 1, 0)
return final_img

for image_id in range(900,1000):

original_image = cv2.imread('./Train/CameraRGB/{}.png'.format(image_id))[:,:,::-1]

final_img = segment_images(original_image)

cv2.imshow('output', final_img[:,:,::-1])
cv2.waitKey(1)
try:
original_image = cv2.imread('./Train/CameraRGB/{}.png'.format(image_id))[:,:,::-1]

final_img = segment_images(original_image)

cv2.imshow('output', final_img[:,:,::-1])
cv2.waitKey(5)
except KeyboardInterrupt as e:
break

exit()
42 changes: 27 additions & 15 deletions test_inference.py
Original file line number Diff line number Diff line change
@@ -8,11 +8,11 @@

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

ROOT_DIR = os.path.abspath("./")
MODEL_DIR = os.path.join('./', "logs")
ROOT_DIR = os.path.abspath("./Lyft_challenge")
MODEL_DIR = os.path.join('./Lyft_challenge', "logs")

sys.path.append(ROOT_DIR) # To find local version of the library
sys.path.append(os.path.join(os.getcwd(),"./Mask_RCNN/"))
sys.path.append(os.path.join(os.getcwd(),"./Lyft_challenge/Mask_RCNN/"))

from mrcnn.config import Config
from mrcnn import utils
@@ -58,7 +58,7 @@ class ShapesConfig(Config):


config = ShapesConfig()
config.display()
# config.display()


class InferenceConfig(ShapesConfig):
@@ -69,25 +69,35 @@ class InferenceConfig(ShapesConfig):

file = sys.argv[-1]

if file == 'demo.py':
if file == 'test_inference.py':
print ("Error loading video")
quit

model = modellib.MaskRCNN(mode="inference",
config=inference_config,
model_dir=MODEL_DIR)

model_path = os.path.join('./', "mask_rcnn_lyft.h5")
model_path = os.path.join('./Lyft_challenge', "mask_rcnn_lyft.h5")
assert model_path != "", "Provide path to trained weights"
print("Loading weights from ", model_path)
# print("Loading weights from ", model_path)
model.load_weights(model_path, by_name=True)


ROAD = 0
CAR = 1
def segment_image(image_frame):
results = model.detect([image_frame], verbose=0)
r = results[0]
road_mask = r['masks'][:,:,0]
car_mask = r['masks'][:,:,1]
no_ch = r['masks'].shape[2]
if no_ch < 2:
if r["class_ids"]==0:
road_mask = r['masks'][:,:,0]
car_mask = np.zeros(road_mask.shape)
else:
car_mask = r['masks'][:,:,0]
road_mask = np.zeros(car_mask.shape)
else:
road_mask = r['masks'][:,:,0]
car_mask = r['masks'][:,:,1]

return car_mask,road_mask

@@ -110,13 +120,15 @@ def encode(array):
# Grab red channel
# red = rgb_frame[:,:,0]
# Look for red cars :)
# binary_car_result = np.where(red>250,1,0).astype('uint8')
#

# Look for road :)
# binary_road_result = binary_car_result = np.where(red<20,1,0).astype('uint8')
binary_car_result,binary_road_result = segment_image(rgb_frame)

answer_key[frame] = [encode(binary_car_result), encode(binary_road_result)]
#
car_mask,road_mask = segment_image(rgb_frame)
binary_car_result = car_mask*1
binary_road_result = road_mask*1

answer_key[frame] = [encode(binary_car_result.astype('uint8')), encode(binary_road_result.astype('uint8'))]

# Increment frame
frame+=1

0 comments on commit 2816315

Please sign in to comment.