Skip to content

Commit

Permalink
fix mask_to_seq bug
Browse files Browse the repository at this point in the history
  • Loading branch information
plliao committed Jun 19, 2017
1 parent dbcd736 commit 3d998ae
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
31 changes: 28 additions & 3 deletions test/wrapper/test_mask_to_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,24 @@ def test_output_shape(self):
def test_image_data_mask_value(self):
result = self.model.predict(self.data)
np.testing.assert_almost_equal(
result,
result[:, :self.x_start, :],
np.zeros((
self.batch_size,
self.x,
self.x_start,
self.y * self.channel_size
))
)
np.testing.assert_almost_equal(
result[:, self.x_end:, :],
np.zeros((
self.batch_size,
self.x - self.x_end,
self.y * self.channel_size
))
)
self.assertTrue(
np.sum(result[:, self.x_start:self.x_end, :], dtype=bool)
)

def test_seq_data_mask_value(self):
result = self.model.predict(self.seq_data)
Expand All @@ -71,7 +82,21 @@ def test_image_data_mask(self):
session=K.get_session(),
feed_dict={self.model.input: self.data}
)
self.assertFalse(np.any(mask))
self.assertTrue(
np.all(
mask[:, self.x_start:self.x_end]
)
)
self.assertFalse(
np.any(
mask[:, :self.x_start]
)
)
self.assertFalse(
np.any(
mask[:, self.x_end:]
)
)

def test_seq_data_mask(self):
mask_cache_key = str(id(self.model.input)) + '_' + str(id(None))
Expand Down
2 changes: 1 addition & 1 deletion yklz/wrapper/mask_to_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def compute_mask(self, inputs, mask=None):

reduce_time = len(mask_shape) - 2
for _ in range(reduce_time):
mask_tensor = K.all(mask_tensor, -1)
mask_tensor = K.any(mask_tensor, -1)
return mask_tensor

def call(self, inputs, mask=None):
Expand Down

0 comments on commit 3d998ae

Please sign in to comment.