From 7e25fc751bf5f8146ec57bc7fc1001cc37221390 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 22 Sep 2024 18:02:02 -0400
Subject: [PATCH] test: recurrent NN for enzyme

---
 test/layers/recurrent_tests.jl | 159 ++++++++++++++++----------------
 1 file changed, 75 insertions(+), 84 deletions(-)

diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl
index 78c584cff..dc2978720 100644
--- a/test/layers/recurrent_tests.jl
+++ b/test/layers/recurrent_tests.jl
@@ -1,11 +1,38 @@
-@testitem "RNNCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
+@testsetup module RecurrentLayersSetup
+
+using MLDataDevices
+
+MLDataDevices.get_device_type(::Function) = Nothing # FIXME: upstream maybe?
+MLDataDevices.get_device_type(_) = Nothing # FIXME: upstream maybe?
+
+function loss_loop(cell, x, p, st)
+    (y, carry), st_ = cell(x, p, st)
+    for _ in 1:3
+        (y, carry), st_ = cell((x, carry), p, st_)
+    end
+    return sum(abs2, y)
+end
+
+function loss_loop_no_carry(cell, x, p, st)
+    y, st_ = cell(x, p, st)
+    for _ in 1:3
+        y, st_ = cell(x, p, st_)
+    end
+    return sum(abs2, y)
+end
+
+export loss_loop, loss_loop_no_carry
+
+end
+
+@testitem "RNNCell" setup=[SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
     rng = StableRNG(12345)
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
-        @testset for rnncell in (RNNCell(3 => 5, identity), RNNCell(3 => 5, tanh),
-            RNNCell(3 => 5, tanh; use_bias=false),
-            RNNCell(3 => 5, identity; use_bias=false),
-            RNNCell(3 => 5, identity; use_bias=false, train_state=false))
+        @testset for act in (identity, tanh), use_bias in (true, false),
+            train_state in (true, false)
+
+            rnncell = RNNCell(3 => 5, act; use_bias, train_state)
             display(rnncell)
             ps, st = Lux.setup(rng, rnncell) |> dev
             @testset for x_size in ((3, 2), (3,))
@@ -15,34 +42,27 @@
                 @jet rnncell(x, ps, st)
                 @jet rnncell((x, carry), ps, st)
 
-                function loss_loop_rnncell(p)
-                    (y, carry), st_ = rnncell(x, p, st)
-                    for _ in 1:10
-                        (y, carry), st_ = rnncell((x, carry), p, st_)
-                    end
-                    return sum(abs2, y)
+                if train_state
+                    @test hasproperty(ps, :train_state)
+                else
+                    @test !hasproperty(ps, :train_state)
                 end
 
-                @test_throws ErrorException ps.train_state
-
-                @test_gradients(loss_loop_rnncell, ps; atol=1.0f-3, rtol=1.0f-3,
-                    soft_fail=[AutoFiniteDiff()], broken_backends=[AutoEnzyme()])
+                @test_gradients(loss_loop, rnncell, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
             end
         end
 
         @testset "Trainable hidden states" begin
-            @testset for rnncell in (
-                RNNCell(3 => 5, identity; use_bias=false, train_state=true),
-                RNNCell(3 => 5, identity; use_bias=true, train_state=true))
+            @testset for use_bias in (true, false)
+                rnncell = RNNCell(3 => 5, identity; use_bias, train_state=true)
                 rnn_no_trainable_state = RNNCell(
                     3 => 5, identity; use_bias=false, train_state=false)
                 _ps, _st = Lux.setup(rng, rnn_no_trainable_state) |> dev
 
-                rnncell = RNNCell(3 => 5, identity; use_bias=false, train_state=true)
                 ps, st = Lux.setup(rng, rnncell) |> dev
                 ps = merge(_ps, (hidden_state=ps.hidden_state,))
 
-                for x_size in ((3, 2), (3,))
+                @testset for x_size in ((3, 2), (3,))
                     x = randn(rng, Float32, x_size...) |> aType
 
                     (_y, _carry), _ = Lux.apply(rnn_no_trainable_state, x, _ps, _st)
@@ -59,12 +79,12 @@ end
 end
 
-@testitem "LSTMCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
+@testitem "LSTMCell" setup=[SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
     rng = StableRNG(12345)
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
-        @testset for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true),
-            LSTMCell(3 => 5; use_bias=false))
+        @testset for use_bias in (true, false)
+            lstmcell = LSTMCell(3 => 5; use_bias)
             display(lstmcell)
             ps, st = Lux.setup(rng, lstmcell) |> dev
@@ -75,19 +95,10 @@ end
                 @jet lstmcell(x, ps, st)
                 @jet lstmcell((x, carry), ps, st)
 
-                function loss_loop_lstmcell(p)
-                    (y, carry), st_ = lstmcell(x, p, st)
-                    for i in 1:10
-                        (y, carry), st_ = lstmcell((x, carry), p, st_)
-                    end
-                    return sum(abs2, y)
-                end
-
-                @test_gradients(loss_loop_lstmcell, ps; atol=1.0f-3, rtol=1.0f-3,
-                    soft_fail=[AutoFiniteDiff()], broken_backends=[AutoEnzyme()])
+                @test !hasproperty(ps, :train_state)
+                @test !hasproperty(ps, :train_memory)
 
-                @test_throws ErrorException ps.train_state
-                @test_throws ErrorException ps.train_memory
+                @test_gradients(loss_loop, lstmcell, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
             end
         end
 
@@ -108,10 +119,11 @@ end
             l, back = Zygote.pullback(
                 p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps)
             gs = back(one(l))[1]
-            @test_throws ErrorException gs.bias_ih
-            @test_throws ErrorException gs.bias_hh
-            @test_throws ErrorException gs.hidden_state
-            @test_throws ErrorException gs.memory
+
+            @test !hasproperty(gs, :bias_ih)
+            @test !hasproperty(gs, :bias_hh)
+            @test !hasproperty(gs, :hidden_state)
+            @test !hasproperty(gs, :memory)
 
             lstm = LSTMCell(
                 3 => 5; use_bias=false, train_state=true, train_memory=false)
@@ -122,10 +134,10 @@ end
             l, back = Zygote.pullback(
                 p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps)
             gs = back(one(l))[1]
-            @test_throws ErrorException gs.bias_ih
-            @test_throws ErrorException gs.bias_hh
+            @test !hasproperty(gs, :bias_ih)
+            @test !hasproperty(gs, :bias_hh)
             @test !isnothing(gs.hidden_state)
-            @test_throws ErrorException gs.memory
+            @test !hasproperty(gs, :memory)
 
             lstm = LSTMCell(
                 3 => 5; use_bias=false, train_state=false, train_memory=true)
@@ -136,9 +148,9 @@ end
             l, back = Zygote.pullback(
                 p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps)
             gs = back(one(l))[1]
-            @test_throws ErrorException gs.bias_ih
-            @test_throws ErrorException gs.bias_hh
-            @test_throws ErrorException gs.hidden_state
+            @test !hasproperty(gs, :bias_ih)
+            @test !hasproperty(gs, :bias_hh)
+            @test !hasproperty(gs, :hidden_state)
             @test !isnothing(gs.memory)
 
             lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true)
@@ -149,8 +161,8 @@ end
             l, back = Zygote.pullback(
                 p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps)
             gs = back(one(l))[1]
-            @test_throws ErrorException gs.bias_ih
-            @test_throws ErrorException gs.bias_hh
+            @test !hasproperty(gs, :bias_ih)
+            @test !hasproperty(gs, :bias_hh)
             @test !isnothing(gs.hidden_state)
             @test !isnothing(gs.memory)
 
@@ -170,12 +182,12 @@ end
     end
 end
 
-@testitem "GRUCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
+@testitem "GRUCell" setup=[SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
     rng = StableRNG(12345)
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
-        @testset for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true),
-            GRUCell(3 => 5; use_bias=false))
+        @testset for use_bias in (true, false)
+            grucell = GRUCell(3 => 5; use_bias)
             display(grucell)
             ps, st = Lux.setup(rng, grucell) |> dev
@@ -186,18 +198,9 @@ end
                 @jet grucell(x, ps, st)
                 @jet grucell((x, carry), ps, st)
 
-                function loss_loop_grucell(p)
-                    (y, carry), st_ = grucell(x, p, st)
-                    for i in 1:10
-                        (y, carry), st_ = grucell((x, carry), p, st_)
-                    end
-                    return sum(abs2, y)
-                end
+                @test !hasproperty(ps, :train_state)
 
-                @test_gradients(loss_loop_grucell, ps; atol=1e-3, rtol=1e-3,
-                    soft_fail=[AutoFiniteDiff()], broken_backends=[AutoEnzyme()])
-
-                @test_throws ErrorException ps.train_state
+                @test_gradients(loss_loop, grucell, x, ps, st; atol=1e-3, rtol=1e-3)
             end
         end
 
@@ -215,9 +218,10 @@ end
             @test carry == _carry
             l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps)
             gs = back(one(l))[1]
-            @test_throws ErrorException gs.bias_ih
-            @test_throws ErrorException gs.bias_hh
-            @test_throws ErrorException gs.hidden_state
+
+            @test !hasproperty(gs, :bias_ih)
+            @test !hasproperty(gs, :bias_hh)
+            @test !hasproperty(gs, :hidden_state)
 
             gru = GRUCell(3 => 5; use_bias=false, train_state=true)
             ps, st = Lux.setup(rng, gru) |> dev
@@ -241,19 +245,18 @@ end
     end
 end
 
-@testitem "StatefulRecurrentCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
+@testitem "StatefulRecurrentCell" setup=[SharedTestSetup, RecurrentLayersSetup] tags=[:recurrent_layers] begin
     rng = StableRNG(12345)
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
-        for _cell in (RNNCell, LSTMCell, GRUCell),
-            use_bias in (true, false),
+        @testset for _cell in (RNNCell, LSTMCell, GRUCell), use_bias in (true, false),
             train_state in (true, false)
 
             cell = _cell(3 => 5; use_bias, train_state)
             rnn = StatefulRecurrentCell(cell)
             display(rnn)
 
-            for x_size in ((3, 2), (3,))
+            @testset for x_size in ((3, 2), (3,))
                 x = randn(rng, Float32, x_size...) |> aType
                 ps, st = Lux.setup(rng, rnn) |> dev
@@ -273,16 +276,7 @@ end
                 st__ = Lux.update_state(st, :carry, nothing)
                 @test st__.carry === nothing
 
-                function loss_loop_rnn(p)
-                    y, st_ = rnn(x, p, st)
-                    for i in 1:10
-                        y, st_ = rnn(x, p, st_)
-                    end
-                    return sum(abs2, y)
-                end
-
-                @test_gradients(loss_loop_rnn, ps; atol=1e-3, rtol=1e-3,
-                    broken_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()])
+                @test_gradients(loss_loop_no_carry, rnn, x, ps, st; atol=1e-3, rtol=1e-3)
             end
         end
     end
@@ -292,12 +286,9 @@ end
     rng = StableRNG(12345)
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
-        @testset "ordering: $ordering" for ordering in (BatchLastIndex(), TimeLastIndex())
-            @testset "cell: $_cell" for _cell in (RNNCell, LSTMCell, GRUCell)
-                @testset "use_bias: $use_bias, train_state: $train_state" for use_bias in (
-                    true, false),
-                    train_state in (true, false)
-
+        @testset for ordering in (BatchLastIndex(), TimeLastIndex())
+            @testset for _cell in (RNNCell, LSTMCell, GRUCell)
+                @testset for use_bias in (true, false), train_state in (true, false)
                     cell = _cell(3 => 5; use_bias, train_state)
                     rnn = Recurrence(cell; ordering)
                     rnn_seq = Recurrence(cell; ordering, return_sequence=true)
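
For context, here is a minimal standalone sketch of what the shared `loss_loop` helper exercises, for anyone reproducing the gradient check outside the test harness. It assumes only Lux, StableRNGs, and Zygote are available; the `@test_gradients` macro used in the patch additionally cross-checks other AD backends (such as Enzyme and finite differences), which is not reproduced here.

    using Lux, StableRNGs, Zygote

    rng = StableRNG(12345)
    rnncell = RNNCell(3 => 5, tanh)
    ps, st = Lux.setup(rng, rnncell)
    x = randn(rng, Float32, 3, 2)  # (in_dims, batch)

    # Same structure as the patch's `loss_loop`: the first call creates the
    # carry, subsequent calls feed `(x, carry)` back through the cell.
    function loss_loop(cell, x, p, st)
        (y, carry), st_ = cell(x, p, st)
        for _ in 1:3
            (y, carry), st_ = cell((x, carry), p, st_)
        end
        return sum(abs2, y)
    end

    # Reverse-mode gradient with respect to the parameters only;
    # `x` and `st` are closed over.
    gs = Zygote.gradient(p -> loss_loop(rnncell, x, p, st), ps)[1]

Passing the cell, input, parameters, and state as explicit arguments, rather than closing over them as the removed `loss_loop_rnncell`-style closures did, is what lets a single exported helper be reused across the RNNCell, LSTMCell, and GRUCell test items.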