Skip to content

Commit

Permalink
minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 31, 2017
1 parent cd77cac commit b087478
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
2 changes: 2 additions & 0 deletions keras_rcnn/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools

import keras.backend
import numpy
import tensorflow
Expand Down
2 changes: 1 addition & 1 deletion keras_rcnn/layers/object_detection/_object_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class ObjectProposal(keras.engine.topology.Layer):
"""

def __init__(self, proposals, **kwargs):
self.output_dim = (None, proposals, 4)
self.output_dim = (None, None, 4)

self.proposals = proposals

Expand Down
8 changes: 8 additions & 0 deletions tests/layers/object_detection/test_anchor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import keras.backend

import keras_rcnn.backend
import keras_rcnn.layers.object_detection


class TestAnchor:
Expand All @@ -18,6 +19,13 @@ def test_call(self, anchor_layer, gt_boxes):

assert n_all_bbox == 1764

# def test_predict(self):
# shape = (None, 5)
#
# x = keras.layers.Input(shape)
#
# y = keras_rcnn.layers.object_detection.Anchor((14, 14), (224, 224))(x)

def test_regression(self, anchor_layer, gt_boxes):
_, y_true, inds_inside, _ = anchor_layer.call(gt_boxes)

Expand Down
10 changes: 2 additions & 8 deletions tests/layers/object_detection/test_object_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,16 @@ def test_call(self):

y = keras.layers.MaxPooling2D(strides=(2, 2))(y)

y = keras.layers.Conv2D(256, **options)(y)
y = keras.layers.Conv2D(256, **options)(y)
y = keras.layers.Conv2D(256, **options)(y)
y = keras.layers.Conv2D(256, **options)(y)

y = keras.layers.MaxPooling2D(strides=(2, 2))(y)

y = keras.layers.Conv2D(512, **options)(y)
y = keras.layers.Conv2D(512, **options)(y)
y = keras.layers.Conv2D(512, **options)(y)
y = keras.layers.Conv2D(512, **options)(y)

y = keras.layers.MaxPooling2D(strides=(2, 2))(y)

y = keras.layers.Conv2D(512, **options)(y)
y = keras.layers.Conv2D(512, **options)(y)
y = keras.layers.Conv2D(512, **options)(y)

y = keras.layers.Conv2D(512, **options)(y)
Expand All @@ -60,11 +54,11 @@ def test_call(self):

model.compile("sgd", "mse")

image = numpy.random.rand(1, 224, 224, 3)
image = numpy.random.rand(1, *shape)

prediction = model.predict(image)

assert prediction.shape == (1, 300, 4)

def test_compute_output_shape(self, object_proposal_layer):
assert object_proposal_layer.compute_output_shape((14, 14)) == (None, 300, 4)
assert object_proposal_layer.compute_output_shape((14, 14)) == (None, None, 4)

0 comments on commit b087478

Please sign in to comment.