diff --git a/test/autodiff/nested_autodiff_tests.jl b/test/autodiff/nested_autodiff_tests.jl index 7f93637e5..113144643 100644 --- a/test/autodiff/nested_autodiff_tests.jl +++ b/test/autodiff/nested_autodiff_tests.jl @@ -12,7 +12,9 @@ function test_nested_ad_input_gradient_jacobian(aType, dev, ongpu, loss_fn, X, m ps, st = Lux.setup(rng, model) |> dev X = aType(X) - l = loss_fn(model, X, ps, st) + l = allow_unstable() do + loss_fn(model, X, ps, st) + end @test l isa Number @test isfinite(l) && !isnan(l) @@ -25,9 +27,11 @@ function test_nested_ad_input_gradient_jacobian(aType, dev, ongpu, loss_fn, X, m !iszero(ComponentArray(∂ps |> cpu_device())) && all(x -> x === nothing || isfinite(x), ComponentArray(∂ps |> cpu_device())) - test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; - atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoForwardDiff()], - skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) + allow_unstable() do + test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; + atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoForwardDiff()], + skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) + end end const Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4), @@ -133,7 +137,9 @@ function test_nested_ad_parameter_gradient_jacobian(aType, dev, ongpu, loss_fn, st = st |> dev X = aType(X) - l = loss_fn(model, X, ps, st) + l = allow_unstable() do + loss_fn(model, X, ps, st) + end @test l isa Number @test isfinite(l) && !isnan(l) @@ -146,9 +152,11 @@ function test_nested_ad_parameter_gradient_jacobian(aType, dev, ongpu, loss_fn, !iszero(ComponentArray(∂ps |> cpu_device())) && all(x -> x === nothing || isfinite(x), ComponentArray(∂ps |> cpu_device())) - test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; - atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoForwardDiff()], - skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) + allow_unstable() do + test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; + atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoForwardDiff()], + skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) + end end const Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4), diff --git a/test/runtests.jl b/test/runtests.jl index d9f71aec5..fb0950818 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,18 @@ end using Lux +# Type Stability tests fail if run with DispatchDoctor enabled +if "all" in LUX_TEST_GROUP || "core_layers" in LUX_TEST_GROUP + try + # Run in a separate process to load the updated preferences + run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) + --startup-file=no --code-coverage=user $(@__DIR__)/zygote_type_stability.jl`) + @test true + catch + @test false + end +end + Lux.set_dispatch_doctor_preferences!(; luxcore="error", luxlib="error") @testset "Load Tests" begin @@ -85,7 +97,7 @@ const RETESTITEMS_NWORKERS = parse( @info "Running tests for group: [$(i)/$(length(LUX_TEST_GROUP))] $tag" ReTestItems.runtests(Lux; tags=(tag == "all" ? nothing : [Symbol(tag)]), - nworkers=RETESTITEMS_NWORKERS, testitem_timeout=1800, retries=2) + nworkers=RETESTITEMS_NWORKERS, testitem_timeout=1800, retries=1) end end @@ -170,17 +182,3 @@ if ("all" in LUX_TEST_GROUP || "others" in LUX_TEST_GROUP) end end end - -# Type Stability tests fail if run with DispatchDoctor enabled -Lux.set_dispatch_doctor_preferences!(; luxcore="disable", luxlib="disable") - -if "all" in LUX_TEST_GROUP || "core_layers" in LUX_TEST_GROUP - try - # Run in a separate process to load the updated preferences - run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) - --startup-file=no --code-coverage=user $(@__DIR__)/zygote_type_stability.jl`) - @test true - catch - @test false - end -end diff --git a/test/zygote_type_stability.jl b/test/zygote_type_stability.jl index 30007212d..517ba590f 100644 --- a/test/zygote_type_stability.jl +++ b/test/zygote_type_stability.jl @@ -72,7 +72,6 @@ include("setup_modes.jl") model in model_list, input in inputs - model = maybe_rewrite_to_crosscor(mode, model) ps, st = Lux.setup(rng, model) |> dev x = input |> dev