test: remove unnecessary Enzyme runtime API
avik-pal committed Sep 6, 2024
1 parent 15b20e4 commit 51956a9
Showing 2 changed files with 23 additions and 27 deletions.
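The test change drops the global `Enzyme.API.runtimeActivity!(true)` toggle from each recurrent-layer test item. For context, a minimal sketch of the per-call alternative, assuming a recent Enzyme.jl that provides `set_runtime_activity`; the function `f` and the arrays below are illustrative and not part of this commit:

```julia
using Enzyme

# Toy function whose gradient we take; purely illustrative.
f(x) = sum(abs2, x)

x = randn(Float32, 4)
dx = zero(x)            # gradient accumulator for Duplicated

# Instead of the removed global switch `Enzyme.API.runtimeActivity!(true)`,
# runtime activity can be requested for a single reverse-mode call:
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Reverse), f, Active,
    Duplicated(x, dx))

dx   # ≈ 2 .* x
```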
2 changes: 2 additions & 0 deletions docs/src/introduction/updating_to_v1.md
@@ -140,3 +140,5 @@ abstraction.
align with Pytorch. Both are controlled using `init_bias` and `use_bias`.
- [`ConvTranspose`](@ref) allows `flipkernel=true` via `cross_correlation=true`. This makes
it efficient for MIOpen.
- Pooling Layers based on lpnorm have been added -- [`LPPool`](@ref),
[`GlobalLPPool`](@ref), and [`AdaptiveLPPool`](@ref).
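The changelog entries above mention a few v1 surface changes. A small sketch of how those keyword arguments read in practice, assuming the Lux.jl v1 constructors; the layer sizes are arbitrary, and the new LP-norm pooling layers are omitted because their constructor signature is not shown in this hunk:

```julia
using Lux, Random

rng = Random.default_rng()

# Bias handling is controlled through `use_bias` / `init_bias`.
dense = Dense(3 => 5; use_bias=true, init_bias=zeros32)

# `cross_correlation=true` gives the flipped-kernel (MIOpen-friendly) behaviour.
convt = ConvTranspose((3, 3), 3 => 8; use_bias=false, cross_correlation=true)

ps, st = Lux.setup(rng, dense)
y, _ = dense(randn(rng, Float32, 3, 2), ps, st)   # 5×2 output
```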
48 changes: 21 additions & 27 deletions test/layers/recurrent_tests.jl
@@ -1,5 +1,4 @@
@testitem "RNNCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
Enzyme.API.runtimeActivity!(true)
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@@ -8,7 +7,7 @@
RNNCell(3 => 5, identity; use_bias=false),
RNNCell(3 => 5, identity; use_bias=false, train_state=false))
display(rnncell)
ps, st = Lux.setup(rng, rnncell) .|> dev
ps, st = Lux.setup(rng, rnncell) |> dev
for x_size in ((3, 2), (3,))
x = randn(rng, Float32, x_size...) |> aType
(y, carry), st_ = Lux.apply(rnncell, x, ps, st)
@@ -36,10 +35,10 @@
RNNCell(3 => 5, identity; use_bias=true, train_state=true))
rnn_no_trainable_state = RNNCell(
3 => 5, identity; use_bias=false, train_state=false)
_ps, _st = Lux.setup(rng, rnn_no_trainable_state) .|> dev
_ps, _st = Lux.setup(rng, rnn_no_trainable_state) |> dev

rnncell = RNNCell(3 => 5, identity; use_bias=false, train_state=true)
ps, st = Lux.setup(rng, rnncell) .|> dev
ps, st = Lux.setup(rng, rnncell) |> dev
ps = merge(_ps, (hidden_state=ps.hidden_state,))

for x_size in ((3, 2), (3,))
@@ -60,14 +59,13 @@
end

@testitem "LSTMCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
Enzyme.API.runtimeActivity!(true)
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true),
LSTMCell(3 => 5; use_bias=false))
display(lstmcell)
ps, st = Lux.setup(rng, lstmcell) .|> dev
ps, st = Lux.setup(rng, lstmcell) |> dev

for x_size in ((3, 2), (3,))
x = randn(rng, Float32, x_size...) |> aType
@@ -97,12 +95,12 @@ end
x = randn(rng, Float32, x_size...) |> aType
_lstm = LSTMCell(
3 => 5; use_bias=false, train_state=false, train_memory=false)
_ps, _st = Lux.setup(rng, _lstm) .|> dev
_ps, _st = Lux.setup(rng, _lstm) |> dev
(_y, _carry), _ = Lux.apply(_lstm, x, _ps, _st)

lstm = LSTMCell(
3 => 5; use_bias=false, train_state=false, train_memory=false)
ps, st = Lux.setup(rng, lstm) .|> dev
ps, st = Lux.setup(rng, lstm) |> dev
ps = _ps
(y, carry), _ = Lux.apply(lstm, x, ps, st)
@test carry == _carry
@@ -116,7 +114,7 @@ end

lstm = LSTMCell(
3 => 5; use_bias=false, train_state=true, train_memory=false)
ps, st = Lux.setup(rng, lstm) .|> dev
ps, st = Lux.setup(rng, lstm) |> dev
ps = merge(_ps, (hidden_state=ps.hidden_state,))
(y, carry), _ = Lux.apply(lstm, x, ps, st)
@test carry == _carry
@@ -130,7 +128,7 @@ end

lstm = LSTMCell(
3 => 5; use_bias=false, train_state=false, train_memory=true)
ps, st = Lux.setup(rng, lstm) .|> dev
ps, st = Lux.setup(rng, lstm) |> dev
ps = merge(_ps, (memory=ps.memory,))
(y, carry), _ = Lux.apply(lstm, x, ps, st)
@test carry == _carry
@@ -143,7 +141,7 @@ end
@test !isnothing(gs.memory)

lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true)
ps, st = Lux.setup(rng, lstm) .|> dev
ps, st = Lux.setup(rng, lstm) |> dev
ps = merge(_ps, (hidden_state=ps.hidden_state, memory=ps.memory))
(y, carry), _ = Lux.apply(lstm, x, ps, st)
@test carry == _carry
@@ -156,7 +154,7 @@ end
@test !isnothing(gs.memory)

lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true)
ps, st = Lux.setup(rng, lstm) .|> dev
ps, st = Lux.setup(rng, lstm) |> dev
ps = merge(_ps, (; ps.bias_ih, ps.bias_hh, ps.hidden_state, ps.memory))
(y, carry), _ = Lux.apply(lstm, x, ps, st)
l, back = Zygote.pullback(
@@ -172,14 +170,13 @@ end
end

@testitem "GRUCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
Enzyme.API.runtimeActivity!(true)
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true),
GRUCell(3 => 5; use_bias=false))
display(grucell)
ps, st = Lux.setup(rng, grucell) .|> dev
ps, st = Lux.setup(rng, grucell) |> dev

for x_size in ((3, 2), (3,))
x = randn(rng, Float32, x_size...) |> aType
@@ -207,11 +204,11 @@ end
for x_size in ((3, 2), (3,))
x = randn(rng, Float32, x_size...) |> aType
_gru = GRUCell(3 => 5; use_bias=false, train_state=false)
_ps, _st = Lux.setup(rng, _gru) .|> dev
_ps, _st = Lux.setup(rng, _gru) |> dev
(_y, _carry), _ = Lux.apply(_gru, x, _ps, _st)

gru = GRUCell(3 => 5; use_bias=false, train_state=false)
ps, st = Lux.setup(rng, gru) .|> dev
ps, st = Lux.setup(rng, gru) |> dev
ps = _ps
(y, carry), _ = Lux.apply(gru, x, ps, st)
@test carry == _carry
@@ -222,7 +219,7 @@ end
@test_throws ErrorException gs.hidden_state

gru = GRUCell(3 => 5; use_bias=false, train_state=true)
ps, st = Lux.setup(rng, gru) .|> dev
ps, st = Lux.setup(rng, gru) |> dev
ps = merge(_ps, (hidden_state=ps.hidden_state,))
(y, carry), _ = Lux.apply(gru, x, ps, st)
@test carry == _carry
@@ -231,7 +228,7 @@ end
@test !isnothing(gs.hidden_state)

gru = GRUCell(3 => 5; use_bias=true, train_state=true)
ps, st = Lux.setup(rng, gru) .|> dev
ps, st = Lux.setup(rng, gru) |> dev
ps = merge(_ps, (; ps.bias_ih, ps.bias_hh, ps.hidden_state))
(y, carry), _ = Lux.apply(gru, x, ps, st)
@test carry == _carry
@@ -244,7 +241,6 @@ end
end

@testitem "StatefulRecurrentCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
Enzyme.API.runtimeActivity!(true)
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@@ -258,7 +254,7 @@ end

for x_size in ((3, 2), (3,))
x = randn(rng, Float32, x_size...) |> aType
ps, st = Lux.setup(rng, rnn) .|> dev
ps, st = Lux.setup(rng, rnn) |> dev

y, st_ = rnn(x, ps, st)

@@ -292,7 +288,6 @@ end
end

@testitem "Recurrence" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
Enzyme.API.runtimeActivity!(true)
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@@ -318,7 +313,7 @@ end
(ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1))
end

ps, st = Lux.setup(rng, rnn) .|> dev
ps, st = Lux.setup(rng, rnn) |> dev
y, st_ = rnn(x, ps, st)
y_, st__ = rnn_seq(x, ps, st)

@@ -343,7 +338,7 @@ end
randn(rng, Float32, 3, 4) |> aType,
Tuple(randn(rng, Float32, 3) for _ in 1:4) .|> aType,
[randn(rng, Float32, 3) for _ in 1:4] .|> aType)
ps, st = Lux.setup(rng, rnn) .|> dev
ps, st = Lux.setup(rng, rnn) |> dev
y, st_ = rnn(x, ps, st)
y_, st__ = rnn_seq(x, ps, st)

@@ -383,7 +378,7 @@ end
init_state=(rng, args...; kwargs...) -> zeros(args...; kwargs...),
init_bias=(rng, args...; kwargs...) -> zeros(args...; kwargs...));
return_sequence=true)
ps, st = Lux.setup(rng, encoder) .|> dev
ps, st = Lux.setup(rng, encoder) |> dev
m2 = reshape([0.5, 0.0, 0.7, 0.8], 1, :, 1) |> aType
res, _ = encoder(m2, ps, st)

@@ -392,7 +387,6 @@ end
end

@testitem "Bidirectional" setup=[SharedTestSetup] tags=[:recurrent_layers] begin
Enzyme.API.runtimeActivity!(true)
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@@ -404,7 +398,7 @@ end

# Batched Time Series
x = randn(rng, Float32, 3, 4, 2) |> aType
ps, st = Lux.setup(rng, bi_rnn) .|> dev
ps, st = Lux.setup(rng, bi_rnn) |> dev
y, st_ = bi_rnn(x, ps, st)
y_, st__ = bi_rnn_no_merge(x, ps, st)

@@ -440,7 +434,7 @@ end

# Batched Time Series
x = randn(rng, Float32, 3, 4, 2) |> aType
ps, st = Lux.setup(rng, bi_rnn) .|> dev
ps, st = Lux.setup(rng, bi_rnn) |> dev
y, st_ = bi_rnn(x, ps, st)
y_, st__ = bi_rnn_no_merge(x, ps, st)

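The other repeated change in this file swaps `Lux.setup(rng, model) .|> dev` for a plain `|> dev`. A minimal sketch of why both spellings move the parameters and states, assuming `dev` is a device object along the lines of the shared test setup's (e.g. `cpu_device()` / `gpu_device()` from MLDataDevices), which recurses through the `(ps, st)` tuple on its own:

```julia
using Lux, MLDataDevices, Random

rng = Random.default_rng()
dev = cpu_device()          # stand-in for the device used by the test setup

cell = RNNCell(3 => 5)

# Old spelling: broadcast the device over the (ps, st) tuple element-wise.
ps_old, st_old = Lux.setup(rng, cell) .|> dev

# New spelling: a single application is enough, since device objects map over
# tuples and NamedTuples themselves.
ps, st = Lux.setup(rng, cell) |> dev
```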
