From 28908998e0c55025a89e8e2bd26a3fe3e6c84356 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Fri, 29 Sep 2023 15:54:23 +0800 Subject: [PATCH] [Relay][Keras][Bugfix] fix the converters of GRU and SimpleRNN about the go_backwards attribute (#15829) * fix bug in gru and simpleRNN about go_backwards * Update test_forward.py * Update keras.py --- python/tvm/relay/frontend/keras.py | 4 ++++ tests/python/frontend/keras/test_forward.py | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 9e09cb400a..6c82ebb427 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -1062,6 +1062,8 @@ def _convert_simple_rnn( in_bias = etab.new_const(weightList[2]) assert len(in_data.type_annotation.shape) == 3 timeDim = in_data.type_annotation.shape[1].value + if keras_layer.go_backwards: + in_data = _op.reverse(in_data, axis=1) in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1) for i in range(len(in_data_split)): in_data_split_i = _op.nn.batch_flatten(in_data_split[i]) @@ -1090,6 +1092,8 @@ def _convert_gru( recurrent_weight = etab.new_const(weightList[1].transpose([1, 0])) if keras_layer.use_bias: in_bias = etab.new_const(weightList[2]) + if keras_layer.go_backwards: + in_data = _op.reverse(in_data, axis=1) units = list(weightList[0].shape)[1] assert units > 0, "The value of units must be a positive integer" in_data = _op.nn.batch_flatten(in_data) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index ba3880e186..8c5b578060 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -568,12 +568,23 @@ def test_forward_rnn(self, keras_mod): keras_mod.layers.SimpleRNN( units=16, return_state=False, activation="tanh", use_bias=False ), + keras_mod.layers.SimpleRNN( + units=16, return_state=False, activation="tanh", go_backwards=True + ), + keras_mod.layers.GRU( + units=16, + return_state=False, + recurrent_activation="sigmoid", + activation="tanh", + reset_after=False, + ), keras_mod.layers.GRU( units=16, return_state=False, recurrent_activation="sigmoid", activation="tanh", reset_after=False, + use_bias=False, ), keras_mod.layers.GRU( units=16, @@ -582,6 +593,7 @@ def test_forward_rnn(self, keras_mod): activation="tanh", reset_after=False, use_bias=False, + go_backwards=True, ), ] for rnn_func in rnn_funcs: