diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl index bd4b72513..dbd40afb0 100644 --- a/test/layers/recurrent_tests.jl +++ b/test/layers/recurrent_tests.jl @@ -338,13 +338,13 @@ end @test size(y) == (5,) @test length(y_) == 4 @test all(x -> size(x) == (5,), y_) - + if x isa AbstractMatrix && ordering isa BatchLastIndex x2 = reshape(x, Val(3)) - + y2, _ = rnn(x2, ps, st) @test y == vec(y2) - + y2_, _ = rnn_seq(x2, ps, st) @test all(x -> x[1] == vec(x[2]), zip(y_, y2_)) end