[Relay][Keras][Bugfix] fix the converters of GRU and SimpleRNN about the go_backwards attribute (#15829)

* Fix a bug in the GRU and SimpleRNN converters related to go_backwards

* Update test_forward.py

* Update keras.py
jikechao authored Sep 29, 2023
1 parent def551d commit 2890899
Showing 2 changed files with 16 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relay/frontend/keras.py
@@ -1062,6 +1062,8 @@ def _convert_simple_rnn(
         in_bias = etab.new_const(weightList[2])
     assert len(in_data.type_annotation.shape) == 3
     timeDim = in_data.type_annotation.shape[1].value
+    if keras_layer.go_backwards:
+        in_data = _op.reverse(in_data, axis=1)
     in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1)
     for i in range(len(in_data_split)):
         in_data_split_i = _op.nn.batch_flatten(in_data_split[i])
@@ -1090,6 +1092,8 @@ def _convert_gru(
     recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
     if keras_layer.use_bias:
         in_bias = etab.new_const(weightList[2])
+    if keras_layer.go_backwards:
+        in_data = _op.reverse(in_data, axis=1)
     units = list(weightList[0].shape)[1]
     assert units > 0, "The value of units must be a positive integer"
     in_data = _op.nn.batch_flatten(in_data)
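The fix relies on a simple equivalence: when a recurrent layer returns only its final output (return_sequences=False), go_backwards=True produces the same result as reversing the input along the time axis and running the layer forward, which is exactly what the added _op.reverse(in_data, axis=1) call does before the input is split into timesteps. A minimal sketch of that equivalence (assuming TensorFlow/Keras is available; the shapes, units, and tolerance below are illustrative, not taken from this commit):

# Sketch: go_backwards=True vs. a forward pass over the time-reversed input.
import numpy as np
from tensorflow import keras

x = np.random.rand(1, 8, 4).astype("float32")  # (batch, time, features)

fwd = keras.layers.SimpleRNN(units=16, go_backwards=False)
bwd = keras.layers.SimpleRNN(units=16, go_backwards=True)

# Build both layers once, then copy weights so they share the same kernel,
# recurrent kernel, and bias.
_ = fwd(x)
_ = bwd(x)
bwd.set_weights(fwd.get_weights())

out_backward = bwd(x).numpy()
out_forward_on_reversed = fwd(x[:, ::-1, :]).numpy()  # reverse the time axis

np.testing.assert_allclose(out_backward, out_forward_on_reversed, rtol=1e-5)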
12 changes: 12 additions & 0 deletions tests/python/frontend/keras/test_forward.py
@@ -568,12 +568,23 @@ def test_forward_rnn(self, keras_mod):
             keras_mod.layers.SimpleRNN(
                 units=16, return_state=False, activation="tanh", use_bias=False
             ),
+            keras_mod.layers.SimpleRNN(
+                units=16, return_state=False, activation="tanh", go_backwards=True
+            ),
             keras_mod.layers.GRU(
                 units=16,
                 return_state=False,
                 recurrent_activation="sigmoid",
                 activation="tanh",
                 reset_after=False,
             ),
+            keras_mod.layers.GRU(
+                units=16,
+                return_state=False,
+                recurrent_activation="sigmoid",
+                activation="tanh",
+                reset_after=False,
+                use_bias=False,
+            ),
             keras_mod.layers.GRU(
                 units=16,
@@ -582,6 +593,7 @@ def test_forward_rnn(self, keras_mod):
                 activation="tanh",
                 reset_after=False,
                 use_bias=False,
+                go_backwards=True,
             ),
         ]
         for rnn_func in rnn_funcs:
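The new test entries above exercise the go_backwards layers end-to-end through the existing test loop. Outside the test suite, the conversion path touched by this commit can also be exercised directly; a rough sketch (assuming tvm and tensorflow.keras are installed; the input shape and layer size are illustrative, not from this commit):

# Sketch: convert a go_backwards SimpleRNN through the Relay Keras frontend.
from tensorflow import keras
from tvm import relay

inp = keras.layers.Input(shape=(8, 4))
out = keras.layers.SimpleRNN(
    units=16, return_state=False, activation="tanh", go_backwards=True
)(inp)
model = keras.models.Model(inp, out)

shape_dict = {model.input_names[0]: (1, 8, 4)}
mod, params = relay.frontend.from_keras(model, shape_dict)
# With this fix, the printed Relay IR should contain a reverse(..., axis=1)
# on the RNN input before it is split into timesteps.
print(mod)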
