From 8cf451b4225b31d99898d4c7aab540d9b5832fb2 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 21 Sep 2024 23:23:47 -0400
Subject: [PATCH] test: minor test fixes

---
 test/Project.toml              |   2 +-
 test/contrib/freeze_tests.jl   |   6 +-
 test/helpers/compact_tests.jl  |   9 +-
 test/layers/basic_tests.jl     |   4 +-
 test/layers/normalize_tests.jl | 160 +++++++++++++++------------------
 test/layers/recurrent_tests.jl |  18 ++--
 6 files changed, 85 insertions(+), 114 deletions(-)

diff --git a/test/Project.toml b/test/Project.toml
index 7663a0d3d..998901f64 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -57,7 +57,7 @@ LinearAlgebra = "1.10"
 Logging = "1.10"
 LuxCore = "1.0"
 LuxLib = "1.3"
-LuxTestUtils = "1.2.1"
+LuxTestUtils = "1.3"
 MLDataDevices = "1.1"
 MLUtils = "0.4.3"
 NNlib = "0.9.24"
diff --git a/test/contrib/freeze_tests.jl b/test/contrib/freeze_tests.jl
index aa2eafc1f..fd713a34d 100644
--- a/test/contrib/freeze_tests.jl
+++ b/test/contrib/freeze_tests.jl
@@ -35,7 +35,8 @@
         @jet m(x, ps_c, st)

         __f = (x, ps) -> sum(first(m(x, ps, st)))
-        @test_gradients(__f, x, ps_c; atol=1.0f-3, rtol=1.0f-3)
+        @test_gradients(__f, x, ps_c; atol=1.0f-3, rtol=1.0f-3,
+            enzyme_set_runtime_activity=true)
     end

     @testset "LuxDL/Lux.jl#427" begin
@@ -84,7 +85,8 @@ end
         @jet fd(x, ps, st)

         __f = (x, ps) -> sum(first(fd(x, ps, st)))
-        @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
+        @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
+            enzyme_set_runtime_activity=true)

         fd = Lux.Experimental.freeze(d, ())
         @test fd === d
diff --git a/test/helpers/compact_tests.jl b/test/helpers/compact_tests.jl
index 13e799d4b..49988960b 100644
--- a/test/helpers/compact_tests.jl
+++ b/test/helpers/compact_tests.jl
@@ -316,7 +316,7 @@
             @compact(W=randn(d_out, d_in), b=zeros(d_out), incr=1) do x
                 y = W * x
                 incr *= 10
-                return act.(y .+ b) .+ incr
+                @return act.(y .+ b) .+ incr
             end
         end

@@ -329,12 +329,7 @@
         @test st_new.incr == 10
         _, st_new = model(x, ps, st_new)
         @test st_new.incr == 100
-
-        # By default creates a closure so type cannot be inferred
-        inf_type = Core.Compiler._return_type(
-            model, Tuple{typeof(x), typeof(ps), typeof(st)}).parameters
-        @test inf_type[1] === Any
-        @test inf_type[2] === NamedTuple
+        @test @inferred(model(x, ps, st)) isa Any

         function ScaledDense2(; d_in=5, d_out=7, act=relu)
             @compact(W=randn(d_out, d_in), b=zeros(d_out), incr=1) do x
diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl
index 8909b3ccd..3c4164c03 100644
--- a/test/layers/basic_tests.jl
+++ b/test/layers/basic_tests.jl
@@ -173,7 +173,7 @@ end
 @testitem "Dense StaticArrays" setup=[SharedTestSetup] tags=[:core_layers] begin
     using StaticArrays, Enzyme, ForwardDiff, ComponentArrays

-    if LuxTestUtils.ENZYME_TESTING_ENABLED && pkgversion(Enzyme) ≥ v"0.12.36"
+    if LuxTestUtils.ENZYME_TESTING_ENABLED
         N = 8
         d = Lux.Dense(N => N)
         ps = (;
@@ -186,7 +186,7 @@ end
             ps -> sum(d(x, ps, (;))[1])
         end
         grad1 = ForwardDiff.gradient(fun, ComponentVector(ps))
-        grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)
+        grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)[1]
         @test maximum(abs, grad1 .- ComponentVector(grad2)) < 1e-6
     end
 end
diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl
index 110ccc9c0..123a5c738 100644
--- a/test/layers/normalize_tests.jl
+++ b/test/layers/normalize_tests.jl
@@ -1,12 +1,12 @@
 @testitem "BatchNorm" setup=[SharedTestSetup] tags=[:normalize_layers] begin
     rng = StableRNG(12345)

-    @testset "$mode" for (mode, aType, device, ongpu) in MODES
+    @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         m = BatchNorm(2)
         x = [1.0f0 3.0f0 5.0f0
              2.0f0 4.0f0 6.0f0] |> aType
         display(m)
-        ps, st = Lux.setup(rng, m) .|> device
+        ps, st = Lux.setup(rng, m) |> dev

         @test Lux.parameterlength(m) == Lux.parameterlength(ps)
         @test Lux.statelength(m) == Lux.statelength(st)
@@ -16,7 +16,7 @@

         y, st_ = pullback(m, x, ps, st)[1]
         st_ = st_ |> CPUDevice()
-        @test check_approx(Array(y), [-1.22474 0 1.22474; -1.22474 0 1.22474]; atol=1.0e-5)
+        @test Array(y)≈[-1.22474 0 1.22474; -1.22474 0 1.22474] atol=1.0e-5
         # julia> x
         #  2×3 Array{Float64,2}:
         #  1.0  3.0  5.0
@@ -29,18 +29,18 @@
         # ∴ update rule with momentum:
         # .1 * 3 + 0 = .3
         # .1 * 4 + 0 = .4
-        @test check_approx(st_.running_mean, reshape([0.3, 0.4], 2, 1))
+        @test st_.running_mean ≈ reshape([0.3, 0.4], 2, 1)

         # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
         # 2×1 Array{Float64,2}:
         #  1.3
         #  1.3
-        @test check_approx(st_.running_var,
-            0.1 .* var(Array(x); dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0])
+        @test st_.running_var ≈
+              0.1 .* var(Array(x); dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0]

-        st_ = Lux.testmode(st_) |> device
+        st_ = Lux.testmode(st_) |> dev
         x_ = m(x, ps, st_)[1] |> CPUDevice()
-        @test check_approx(x_[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5)
+        @test x_[1]≈(1 .- 0.3) / sqrt(1.3) atol=1.0e-5

         @jet m(x, ps, st)
         __f = (x, ps) -> sum(first(m(x, ps, st)))
@@ -51,9 +51,9 @@
             m = BatchNorm(2; affine, track_stats=false)
             x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] |> aType
             display(m)
-            ps, st = Lux.setup(rng, m) .|> device
+            ps, st = Lux.setup(rng, m) |> dev

-            @jet m(x, ps, st)
+            @jet m(x, ps, Lux.testmode(st))

             if affine
                 __f = (x, ps) -> sum(first(m(x, ps, st)))
@@ -70,23 +70,19 @@
             x = [1.0f0 3.0f0 5.0f0
                  2.0f0 4.0f0 6.0f0] |> aType
             display(m)
-            ps, st = Lux.setup(rng, m) .|> device
-            st = Lux.testmode(st)
-            y, st_ = m(x, ps, st)
-            @test check_approx(
-                y, sigmoid.((x .- st_.running_mean) ./ sqrt.(st_.running_var .+ m.epsilon)),
-                atol=1.0e-7)
+            ps, st = Lux.setup(rng, m) |> dev

-            @jet m(x, ps, st)
+            y, st_ = m(x, ps, Lux.testmode(st))
+            @test y ≈
+                  sigmoid.((x .- st_.running_mean) ./ sqrt.(st_.running_var .+ m.epsilon))
+            @jet m(x, ps, Lux.testmode(st))

             if affine
-                st_train = Lux.trainmode(st)
-                __f = (x, ps) -> sum(first(m(x, ps, st_train)))
+                __f = (x, ps) -> sum(first(m(x, ps, st)))
                 @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
                     skip_backends=[AutoFiniteDiff()])
             else
-                st_train = Lux.trainmode(st)
-                __f = x -> sum(first(m(x, ps, st_train)))
+                __f = x -> sum(first(m(x, ps, st)))
                 @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
                     skip_backends=[AutoFiniteDiff()])
             end
@@ -94,12 +90,9 @@
             m = BatchNorm(32; affine)
             x = randn(Float32, 416, 416, 32, 1) |> aType
             display(m)
-            ps, st = Lux.setup(rng, m) .|> device
-            st = Lux.testmode(st)
-            m(x, ps, st)
-            @test (@allocated m(x, ps, st)) < 100_000_000
-
-            @jet m(x, ps, st)
+            ps, st = Lux.setup(rng, m) |> dev
+            m(x, ps, Lux.testmode(st))
+            @jet m(x, ps, Lux.testmode(st))
         end
     end
 end
@@ -107,7 +100,7 @@ end
 @testitem "GroupNorm" setup=[SharedTestSetup] tags=[:normalize_layers] begin
     rng = StableRNG(12345)

-    @testset "$mode" for (mode, aType, device, ongpu) in MODES
+    @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         squeeze(x) = dropdims(x; dims=tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions

         m = GroupNorm(4, 2)
@@ -116,27 +109,25 @@
         display(m)
         x = Float32.(x)
-        ps, st = Lux.setup(rng, m) .|> device
+        ps, st = Lux.setup(rng, m) |> dev

         @test Lux.parameterlength(m) == Lux.parameterlength(ps)
         @test Lux.statelength(m) == Lux.statelength(st)
         @test ps.bias == [0, 0, 0, 0] |> aType # init_bias(32)
         @test ps.scale == [1, 1, 1, 1] |> aType # init_scale(32)

-        y, st_ = pullback(m, x, ps, st)[1]
-
         @jet m(x, ps, st)
         __f = let m = m, x = x, st = st
             ps -> sum(first(m(x, ps, st)))
         end
-        @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3)
+        @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, enzyme_set_runtime_activity=true)

-        @testset "affine: $affine" for affine in (true, false)
+        @testset for affine in (true, false)
             m = GroupNorm(2, 2; affine)
             x = rand(rng, Float32, 3, 2, 1) |> aType
             display(m)
-            ps, st = Lux.setup(rng, m) .|> device
+            ps, st = Lux.setup(rng, m) |> dev

-            @jet m(x, ps, st)
+            @jet m(x, ps, Lux.testmode(st))

             if affine
                 __f = (x, ps) -> sum(first(m(x, ps, st)))
@@ -152,11 +143,9 @@
             m = GroupNorm(2, 2, sigmoid; affine)
             x = randn(rng, Float32, 3, 2, 1) |> aType
             display(m)
-            ps, st = Lux.setup(rng, m) .|> device
-            st = Lux.testmode(st)
-            y, st_ = m(x, ps, st)
-
-            @jet m(x, ps, st)
+            ps, st = Lux.setup(rng, m) |> dev
+            y, st_ = m(x, ps, Lux.testmode(st))
+            @jet m(x, ps, Lux.testmode(st))

             if affine
                 __f = (x, ps) -> sum(first(m(x, ps, st)))
@@ -171,17 +160,9 @@
             m = GroupNorm(32, 16; affine)
             x = randn(rng, Float32, 416, 416, 32, 1) |> aType
             display(m)
-            ps, st = Lux.setup(rng, m) .|> device
-            st = Lux.testmode(st)
-            m(x, ps, st)
-
-            @test (@allocated m(x, ps, st)) < 100_000_000
-
-            if affine
-                LuxTestUtils.JET.@test_opt target_modules=(LuxLib,) m(x, ps, st)
-            else
-                @jet m(x, ps, st)
-            end
+            ps, st = Lux.setup(rng, m) |> dev
+            m(x, ps, Lux.testmode(st))
+            @jet m(x, ps, Lux.testmode(st))
         end

         @test_throws ArgumentError GroupNorm(5, 2)
@@ -191,7 +172,7 @@ end
 @testitem "WeightNorm" setup=[SharedTestSetup] tags=[:normalize_layers] begin
     rng = StableRNG(12345)

-    @testset "$mode" for (mode, aType, device, ongpu) in MODES
+    @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         @testset "Utils.norm_except" begin
             z = randn(rng, Float32, 3, 3, 4, 2) |> aType

@@ -211,7 +192,7 @@ end

             wn = WeightNorm(c, (:weight, :bias))
             display(wn)
-            ps, st = Lux.setup(rng, wn) .|> device
+            ps, st = Lux.setup(rng, wn) |> dev
             x = randn(rng, Float32, 3, 3, 3, 1) |> aType

             @jet wn(x, ps, st)
@@ -220,7 +201,7 @@ end

             wn = WeightNorm(c, (:weight,))
             display(wn)
-            ps, st = Lux.setup(rng, wn) .|> device
+            ps, st = Lux.setup(rng, wn) |> dev
             x = randn(rng, Float32, 3, 3, 3, 1) |> aType

             @jet wn(x, ps, st)
@@ -229,7 +210,7 @@ end

             wn = WeightNorm(c, (:weight, :bias), (2, 2))
             display(wn)
-            ps, st = Lux.setup(rng, wn) .|> device
+            ps, st = Lux.setup(rng, wn) |> dev
             x = randn(rng, Float32, 3, 3, 3, 1) |> aType

             @jet wn(x, ps, st)
@@ -238,7 +219,7 @@ end

             wn = WeightNorm(c, (:weight,), (2,))
             display(wn)
-            ps, st = Lux.setup(rng, wn) .|> device
+            ps, st = Lux.setup(rng, wn) |> dev
             x = randn(rng, Float32, 3, 3, 3, 1) |> aType

             @jet wn(x, ps, st)
@@ -251,7 +232,7 @@ end

             wn = WeightNorm(d, (:weight, :bias))
             display(wn)
-            ps, st = Lux.setup(rng, wn) .|> device
+            ps, st = Lux.setup(rng, wn) |> dev
             x = randn(rng, Float32, 3, 1) |> aType

             @jet wn(x, ps, st)
@@ -260,7 +241,7 @@ end

             wn = WeightNorm(d, (:weight,))
             display(wn)
-            ps, st = Lux.setup(rng, wn) .|> device
+            ps, st = Lux.setup(rng, wn) |> dev
             x = randn(rng, Float32, 3, 1) |> aType

             @jet wn(x, ps, st)
@@ -269,7 +250,7 @@ end

             wn = WeightNorm(d, (:weight, :bias), (2, 2))
             display(wn)
-            ps, st = Lux.setup(rng, wn) .|> device
+            ps, st = Lux.setup(rng, wn) |> dev
             x = randn(rng, Float32, 3, 1) |> aType

             @jet wn(x, ps, st)
@@ -278,7 +259,7 @@ end

             wn = WeightNorm(d, (:weight,), (2,))
             display(wn)
-            ps, st = Lux.setup(rng, wn) .|> device
+            ps, st = Lux.setup(rng, wn) |> dev
             x = randn(rng, Float32, 3, 1) |> aType

             @jet wn(x, ps, st)
@@ -310,21 +291,21 @@ end
 @testitem "LayerNorm" setup=[SharedTestSetup] tags=[:normalize_layers] begin
     rng = StableRNG(12345)

-    @testset "$mode" for (mode, aType, device, ongpu) in MODES
+    @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         x = randn(rng, Float32, 3, 3, 3, 2) |> aType

-        for bshape in ((3, 3, 3), (1, 3, 1), (3, 1, 3))
-            for affine in (true, false)
+        @testset for bshape in ((3, 3, 3), (1, 3, 1), (3, 1, 3))
+            @testset for affine in (true, false)
                 ln = LayerNorm(bshape; affine)
                 display(ln)
-                ps, st = Lux.setup(rng, ln) .|> device
+                ps, st = Lux.setup(rng, ln) |> dev

-                y, st_ = ln(x, ps, st)
+                y, st_ = ln(x, ps, Lux.testmode(st))

-                @test check_approx(mean(y), 0; atol=1.0f-3, rtol=1.0f-3)
-                @test check_approx(std(y), 1; atol=1.0f-2, rtol=1.0f-2)
+                @test mean(y)≈0 atol=1.0f-3
+                @test std(y)≈1 atol=1.0f-2

-                @jet ln(x, ps, st)
+                @jet ln(x, ps, Lux.testmode(st))

                 if affine
                     __f = (x, ps) -> sum(first(ln(x, ps, st)))
@@ -336,19 +317,18 @@
                         skip_backends=[AutoFiniteDiff()])
                 end

-                for act in (sigmoid, tanh)
+                @testset for act in (sigmoid, tanh)
                     ln = LayerNorm(bshape, act; affine)
                     display(ln)
-                    ps, st = Lux.setup(rng, ln) .|> device
+                    ps, st = Lux.setup(rng, ln) |> dev

-                    y, st_ = ln(x, ps, st)
+                    y, st_ = ln(x, ps, Lux.testmode(st))

-                    @jet ln(x, ps, st)
+                    @jet ln(x, ps, Lux.testmode(st))

                     if affine
                         __f = (x, ps) -> sum(first(ln(x, ps, st)))
-                        @test_gradients(__f, x, ps; atol=1.0f-3,
-                            rtol=1.0f-3,
+                        @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
                             skip_backends=[AutoFiniteDiff()])
                     else
                         __f = x -> sum(first(ln(x, ps, st)))
@@ -364,47 +344,47 @@ end
 @testitem "InstanceNorm" setup=[SharedTestSetup] tags=[:normalize_layers] begin
     rng = StableRNG(12345)

-    @testset "$mode" for (mode, aType, device, ongpu) in MODES
-        for x in (randn(rng, Float32, 3, 3, 3, 2), randn(rng, Float32, 3, 3, 2),
+    @testset "$mode" for (mode, aType, dev, ongpu) in MODES
+        @testset "ndims(x) = $(ndims(x))" for x in (
+            randn(rng, Float32, 3, 3, 3, 2), randn(rng, Float32, 3, 3, 2),
             randn(rng, Float32, 3, 3, 3, 3, 2))
             x = x |> aType

-            for affine in (true, false), track_stats in (true, false)
+            @testset for affine in (true, false), track_stats in (true, false)
                 layer = InstanceNorm(3; affine, track_stats)
                 display(layer)
-                ps, st = Lux.setup(rng, layer) |> device
-
-                y, st_ = layer(x, ps, st)
-
-                @jet layer(x, ps, st)
+                ps, st = Lux.setup(rng, layer) |> dev
+
+                y, st_ = layer(x, ps, Lux.testmode(st))
+                @jet layer(x, ps, Lux.testmode(st))

                 if affine
                     __f = (x, ps) -> sum(first(layer(x, ps, st)))
                     @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
-                        skip_backends=[AutoFiniteDiff()])
+                        skip_backends=[AutoFiniteDiff()], enzyme_set_runtime_activity=true)
                 else
                     __f = x -> sum(first(layer(x, ps, st)))
                     @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                        skip_backends=[AutoFiniteDiff()])
+                        skip_backends=[AutoFiniteDiff()], enzyme_set_runtime_activity=true)
                 end

-                for act in (sigmoid, tanh)
+                @testset for act in (sigmoid, tanh)
                     layer = InstanceNorm(3, act; affine, track_stats)
                     display(layer)
-                    ps, st = Lux.setup(rng, layer) |> device
+                    ps, st = Lux.setup(rng, layer) |> dev

-                    y, st_ = layer(x, ps, st)
-
-                    @jet layer(x, ps, st)
+                    y, st_ = layer(x, ps, Lux.testmode(st))
+                    @jet layer(x, ps, Lux.testmode(st))

                     if affine
                         __f = (x, ps) -> sum(first(layer(x, ps, st)))
                         @test_gradients(__f, x, ps; atol=1.0f-3,
-                            rtol=1.0f-3,
+                            rtol=1.0f-3, enzyme_set_runtime_activity=true,
                             skip_backends=[AutoFiniteDiff()])
                     else
                         __f = x -> sum(first(layer(x, ps, st)))
                         @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                            skip_backends=[AutoFiniteDiff()])
+                            skip_backends=[AutoFiniteDiff()],
+                            enzyme_set_runtime_activity=true)
                     end
                 end
             end
diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl
index e249e200f..78c584cff 100644
--- a/test/layers/recurrent_tests.jl
+++ b/test/layers/recurrent_tests.jl
@@ -415,17 +415,13 @@ end
         @test all(x -> size(x) == (5, 2), y_[1])

         __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
-        @test_gradients(__f, ps; atol=1e-3,
-            rtol=1e-3,
-            broken_backends=Sys.isapple() ? [AutoEnzyme()] : [])
+        @test_gradients(__f, ps; atol=1e-3, rtol=1e-3, broken_backends=[AutoEnzyme()])

         __f = p -> begin
             (y1, y2), st_ = bi_rnn_no_merge(x, p, st)
             return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
         end
-        @test_gradients(__f, ps; atol=1e-3,
-            rtol=1e-3,
-            broken_backends=Sys.isapple() ? [AutoEnzyme()] : [])
+        @test_gradients(__f, ps; atol=1e-3, rtol=1e-3, broken_backends=[AutoEnzyme()])

         @testset "backward_cell: $_backward_cell" for _backward_cell in (
             RNNCell, LSTMCell, GRUCell)
@@ -453,17 +449,15 @@ end
             @test all(x -> size(x) == (5, 2), y_[1])

             __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
-            @test_gradients(__f, ps; atol=1e-3,
-                rtol=1e-3,
-                broken_backends=Sys.isapple() ? [AutoEnzyme()] : [])
+            @test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
+                broken_backends=[AutoEnzyme()])

             __f = p -> begin
                 (y1, y2), st_ = bi_rnn_no_merge(x, p, st)
                 return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
             end
-            @test_gradients(__f, ps; atol=1e-3,
-                rtol=1e-3,
-                broken_backends=Sys.isapple() ? [AutoEnzyme()] : [])
+            @test_gradients(__f, ps; atol=1e-3, rtol=1e-3,
+                broken_backends=[AutoEnzyme()])
         end
     end
 end
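
The two assertion idioms that recur throughout this patch can be exercised standalone with nothing but the Test standard library. The snippet below is an illustrative sketch only; the tolerance values and the `square` helper are invented for the example and are not taken from the test suite above.

    using Test

    # `isapprox` (written `≈`) with an explicit tolerance, the form used in place
    # of the old `check_approx` helper in normalize_tests.jl.
    @test [0.3, 0.4] ≈ [0.30001, 0.39999] atol=1.0e-4

    # `@inferred` re-runs the call and errors if the return type cannot be inferred,
    # so `@test @inferred(f(x)) isa Any` turns an inference failure into a test
    # failure. This is the pattern that replaces the manual
    # `Core.Compiler._return_type` inspection in compact_tests.jl.
    square(x) = x * x
    @test @inferred(square(2.0)) isa Any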