Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Ignoring masking layer test for RNN with MXNet backend (#229)
Browse files Browse the repository at this point in the history
* Ignoring masking layer test for RNN with MXNet backend

* Ignoring the test with unroll=False and masking layer enabled

* Modified comments
  • Loading branch information
karan6181 authored and roywei committed Mar 29, 2019
1 parent aad5bf9 commit 06b4848
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions tests/keras/layers/recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def test_statefulness(layer_class):


@rnn_test
@pytest.mark.skipif(K.backend() == 'mxnet',
                    reason='MXNet backend has an issue with Masking layer with `unroll=False` '
                           'in RNN/LSTM/GRU layer. '
                           'Tracking with this issue: '
                           'https://github.com/awslabs/keras-apache-mxnet/issues/228')
def test_masking_correctness(layer_class):
# Check masking: output with left padding and right padding
# should be the same.
Expand Down Expand Up @@ -239,12 +244,16 @@ def test_masking_layer():
inputs = np.random.random((6, 3, 4))
targets = np.abs(np.random.random((6, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)

model = Sequential()
model.add(Masking(input_shape=(3, 4)))
model.add(recurrent.SimpleRNN(units=5, return_sequences=True, unroll=False))
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(inputs, targets, epochs=1, batch_size=100, verbose=1)
if K.backend() != 'mxnet':
# MXNet backend has an issue with Masking layer with `unroll=False`
# in RNN/LSTM/GRU layer.
# Tracking with this issue:
# https://github.com/awslabs/keras-apache-mxnet/issues/228
model = Sequential()
model.add(Masking(input_shape=(3, 4)))
model.add(recurrent.SimpleRNN(units=5, return_sequences=True, unroll=False))
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(inputs, targets, epochs=1, batch_size=100, verbose=1)

model = Sequential()
model.add(Masking(input_shape=(3, 4)))
Expand Down

0 comments on commit 06b4848

Please sign in to comment.