test: minor test fixes
avik-pal committed Sep 22, 2024
1 parent 8ef6606 commit 87e09f9
Showing 5 changed files with 15 additions and 17 deletions.
test/Project.toml (1 addition, 1 deletion)

```diff
@@ -57,7 +57,7 @@ LinearAlgebra = "1.10"
 Logging = "1.10"
 LuxCore = "1.0"
 LuxLib = "1.3"
-LuxTestUtils = "1.2.1"
+LuxTestUtils = "1.3"
 MLDataDevices = "1.1"
 MLUtils = "0.4.3"
 NNlib = "0.9.24"
```
test/contrib/freeze_tests.jl (4 additions, 2 deletions)

```diff
@@ -35,7 +35,8 @@
     @jet m(x, ps_c, st)

     __f = (x, ps) -> sum(first(m(x, ps, st)))
-    @test_gradients(__f, x, ps_c; atol=1.0f-3, rtol=1.0f-3)
+    @test_gradients(__f, x, ps_c; atol=1.0f-3, rtol=1.0f-3,
+        enzyme_set_runtime_activity=true)
 end

 @testset "LuxDL/Lux.jl#427" begin
@@ -84,7 +85,8 @@ end

     @jet fd(x, ps, st)
     __f = (x, ps) -> sum(first(fd(x, ps, st)))
-    @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
+    @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
+        enzyme_set_runtime_activity=true)

     fd = Lux.Experimental.freeze(d, ())
     @test fd === d
```
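Every hunk in this commit threads the same new keyword through `@test_gradients`. A minimal usage sketch, assuming the LuxTestUtils ≥ 1.3 API; the layer, shapes, and the stated rationale are illustrative assumptions, not taken from the diff:

```julia
using Lux, LuxTestUtils, Random, Test

# Hypothetical minimal setup mirroring the frozen-layer tests above.
m = Dense(3 => 2)
ps, st = Lux.setup(Random.default_rng(), m)
x = randn(Float32, 3, 4)

__f = (x, ps) -> sum(first(m(x, ps, st)))
# `enzyme_set_runtime_activity=true` asks the Enzyme backend of
# `@test_gradients` to run with runtime activity analysis enabled; frozen
# (constant) parameters mixed with active ones are a typical reason to need it.
@test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
    enzyme_set_runtime_activity=true)
```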
test/helpers/compact_tests.jl (2 additions, 7 deletions)

```diff
@@ -316,7 +316,7 @@
         @compact(W=randn(d_out, d_in), b=zeros(d_out), incr=1) do x
             y = W * x
             incr *= 10
-            return act.(y .+ b) .+ incr
+            @return act.(y .+ b) .+ incr
         end
     end
@@ -329,12 +329,7 @@
     @test st_new.incr == 10
     _, st_new = model(x, ps, st_new)
     @test st_new.incr == 100
-
-    # By default creates a closure so type cannot be inferred
-    inf_type = Core.Compiler._return_type(
-        model, Tuple{typeof(x), typeof(ps), typeof(st)}).parameters
-    @test inf_type[1] === Any
-    @test inf_type[2] === NamedTuple
+    @test @inferred(model(x, ps, st)) isa Any

     function ScaledDense2(; d_in=5, d_out=7, act=relu)
         @compact(W=randn(d_out, d_in), b=zeros(d_out), incr=1) do x
```
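The `return` → `@return` switch pairs with the inference-test change below it: with `@return`, `@compact` can capture the updated `incr` in the returned state and keep the call inferrable. A hedged sketch of the pattern, with made-up sizes and names:

```julia
using Lux, Random

# Hedged sketch: inside `@compact`, exiting with `@return` (rather than a bare
# `return`) lets Lux record the updated value of `incr` in the layer state and
# keeps the call type-stable, which the `@inferred` test above relies on.
model = @compact(W=randn(2, 3), b=zeros(2), incr=1) do x
    y = W * x
    incr *= 10
    @return y .+ b .+ incr
end

ps, st = Lux.setup(Random.default_rng(), model)
x = randn(3, 4)
_, st_new = model(x, ps, st)
@assert st_new.incr == 10  # the state update was captured via `@return`
```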
test/layers/basic_tests.jl (2 additions, 2 deletions)

```diff
@@ -173,7 +173,7 @@ end
 @testitem "Dense StaticArrays" setup=[SharedTestSetup] tags=[:core_layers] begin
     using StaticArrays, Enzyme, ForwardDiff, ComponentArrays

-    if LuxTestUtils.ENZYME_TESTING_ENABLED && pkgversion(Enzyme) ≥ v"0.12.36"
+    if LuxTestUtils.ENZYME_TESTING_ENABLED
         N = 8
         d = Lux.Dense(N => N)
         ps = (;
@@ -186,7 +186,7 @@
             ps -> sum(d(x, ps, (;))[1])
         end
         grad1 = ForwardDiff.gradient(fun, ComponentVector(ps))
-        grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)
+        grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)[1]
         @test maximum(abs, grad1 .- ComponentVector(grad2)) < 1e-6
     end
 end
```
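The added `[1]` matches newer Enzyme.jl releases (v0.13+, as an assumption), where `Enzyme.gradient` returns a tuple holding one gradient per differentiated argument. A small illustration:

```julia
using Enzyme

f(v) = sum(abs2, v)
# `gradient` returns a tuple with one entry per differentiated argument,
# so the gradient of the single argument is extracted with `[1]`.
g = Enzyme.gradient(Enzyme.Reverse, f, [1.0, 2.0, 3.0])[1]
@assert g == [2.0, 4.0, 6.0]  # d/dv sum(v.^2) = 2v
```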
test/layers/normalize_tests.jl (6 additions, 5 deletions)

```diff
@@ -128,7 +128,7 @@ end
     __f = let m = m, x = x, st = st
         ps -> sum(first(m(x, ps, st)))
     end
-    @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3)
+    @test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, enzyme_set_runtime_activity=true)

     @testset "affine: $affine" for affine in (true, false)
         m = GroupNorm(2, 2; affine)
@@ -380,11 +380,11 @@ end
         if affine
             __f = (x, ps) -> sum(first(layer(x, ps, st)))
             @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
-                skip_backends=[AutoFiniteDiff()])
+                skip_backends=[AutoFiniteDiff()], enzyme_set_runtime_activity=true)
         else
             __f = x -> sum(first(layer(x, ps, st)))
             @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                skip_backends=[AutoFiniteDiff()])
+                skip_backends=[AutoFiniteDiff()], enzyme_set_runtime_activity=true)
         end

         for act in (sigmoid, tanh)
@@ -399,12 +399,13 @@ end
         if affine
             __f = (x, ps) -> sum(first(layer(x, ps, st)))
             @test_gradients(__f, x, ps; atol=1.0f-3,
-                rtol=1.0f-3,
+                rtol=1.0f-3, enzyme_set_runtime_activity=true,
                 skip_backends=[AutoFiniteDiff()])
         else
             __f = x -> sum(first(layer(x, ps, st)))
             @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                skip_backends=[AutoFiniteDiff()])
+                skip_backends=[AutoFiniteDiff()],
+                enzyme_set_runtime_activity=true)
         end
     end
 end
```
