From 3d998aee65aedf462ffa5ccd1e7ade20b618b4a9 Mon Sep 17 00:00:00 2001 From: plliao Date: Mon, 19 Jun 2017 17:42:59 +0800 Subject: [PATCH] fix mask_to_seq bug --- test/wrapper/test_mask_to_seq.py | 31 ++++++++++++++++++++++++++++--- yklz/wrapper/mask_to_seq.py | 2 +- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/test/wrapper/test_mask_to_seq.py b/test/wrapper/test_mask_to_seq.py index 0200153..9317bc7 100644 --- a/test/wrapper/test_mask_to_seq.py +++ b/test/wrapper/test_mask_to_seq.py @@ -42,13 +42,24 @@ def test_output_shape(self): def test_image_data_mask_value(self): result = self.model.predict(self.data) np.testing.assert_almost_equal( - result, + result[:, :self.x_start, :], np.zeros(( self.batch_size, - self.x, + self.x_start, self.y * self.channel_size )) ) + np.testing.assert_almost_equal( + result[:, self.x_end:, :], + np.zeros(( + self.batch_size, + self.x - self.x_end, + self.y * self.channel_size + )) + ) + self.assertTrue( + np.sum(result[:, self.x_start:self.x_end, :], dtype=bool) + ) def test_seq_data_mask_value(self): result = self.model.predict(self.seq_data) @@ -71,7 +82,21 @@ def test_image_data_mask(self): session=K.get_session(), feed_dict={self.model.input: self.data} ) - self.assertFalse(np.any(mask)) + self.assertTrue( + np.all( + mask[:, self.x_start:self.x_end] + ) + ) + self.assertFalse( + np.any( + mask[:, :self.x_start] + ) + ) + self.assertFalse( + np.any( + mask[:, self.x_end:] + ) + ) def test_seq_data_mask(self): mask_cache_key = str(id(self.model.input)) + '_' + str(id(None)) diff --git a/yklz/wrapper/mask_to_seq.py b/yklz/wrapper/mask_to_seq.py index 6b2d905..899f832 100644 --- a/yklz/wrapper/mask_to_seq.py +++ b/yklz/wrapper/mask_to_seq.py @@ -27,7 +27,7 @@ def compute_mask(self, inputs, mask=None): reduce_time = len(mask_shape) - 2 for _ in range(reduce_time): - mask_tensor = K.all(mask_tensor, -1) + mask_tensor = K.any(mask_tensor, -1) return mask_tensor def call(self, inputs, mask=None):